commit
c523c4f0c7
|
@ -1,34 +0,0 @@
|
||||||
# Configuration for lock-threads - https://github.com/dessant/lock-threads
|
|
||||||
|
|
||||||
# Number of days of inactivity before a closed issue or pull request is locked
|
|
||||||
daysUntilLock: 30
|
|
||||||
|
|
||||||
# Issues and pull requests with these labels will not be locked. Set to `[]` to disable
|
|
||||||
exemptLabels: []
|
|
||||||
|
|
||||||
# Label to add before locking, such as `outdated`. Set to `false` to disable
|
|
||||||
lockLabel: false
|
|
||||||
|
|
||||||
# Comment to post before locking. Set to `false` to disable
|
|
||||||
lockComment: >
|
|
||||||
This thread has been automatically locked since there has not been
|
|
||||||
any recent activity after it was closed. Please open a new issue for
|
|
||||||
related bugs.
|
|
||||||
|
|
||||||
# Assign `resolved` as the reason for locking. Set to `false` to disable
|
|
||||||
setLockReason: false
|
|
||||||
|
|
||||||
# Limit to only `issues` or `pulls`
|
|
||||||
only: issues
|
|
||||||
|
|
||||||
# Optionally, specify configuration settings just for `issues` or `pulls`
|
|
||||||
# issues:
|
|
||||||
# exemptLabels:
|
|
||||||
# - help-wanted
|
|
||||||
# lockLabel: outdated
|
|
||||||
|
|
||||||
# pulls:
|
|
||||||
# daysUntilLock: 30
|
|
||||||
|
|
||||||
# Repository to extend settings from
|
|
||||||
# _extends: repo
|
|
|
@ -0,0 +1,58 @@
|
||||||
|
# Onnx runtime module
|
||||||
|
|
||||||
|
## Status
|
||||||
|
Proposed
|
||||||
|
|
||||||
|
Proposed by: Adam Gibson (23-09-2020)
|
||||||
|
|
||||||
|
Discussed with: saudet
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
We need a way of providing nd4j a way of running onnx modules
|
||||||
|
that is easily compatible with the onnx community. The gold standard for this
|
||||||
|
is is using [onnxruntime](https://github.com/microsoft/onnxruntime/blob/master/docs/Java_API.md).
|
||||||
|
|
||||||
|
|
||||||
|
## Decision
|
||||||
|
|
||||||
|
We will use javacpp's onnxruntime bindings in a similar manner to [nd4j-tensorflow](../nd4j-tensorflow)
|
||||||
|
allowing nd4j to be used as an ndarray format that interops with onnxruntime.
|
||||||
|
|
||||||
|
We will implement a simple api similar to the [GraphRunner](../nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java)
|
||||||
|
This will sit on top of javacpp's lower level onnxruntime bindings.
|
||||||
|
|
||||||
|
This module will follow a similar structure to the nd4j-tensorflow module
|
||||||
|
focusing on INDArrays as a data interchange format, but otherwise pass execution
|
||||||
|
down to onnxruntime.
|
||||||
|
|
||||||
|
|
||||||
|
The main api to the graph runner works as follows:
|
||||||
|
|
||||||
|
```java
|
||||||
|
try(GraphRunner runner = new GraphRunner(...)) {
|
||||||
|
Map<String,INDArray> inputs = new HashMap<>();
|
||||||
|
// ..initialize inputs
|
||||||
|
Map<String,INDArray> outputs = runner.run(inputs);
|
||||||
|
// process outputs...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The core logic will contain the following components:
|
||||||
|
|
||||||
|
1. Loading onnx pb files
|
||||||
|
2. A graph runner in similar nature to nd4j-tensorflow
|
||||||
|
3. Interop with onnxruntime's version of an ndarray/tensor
|
||||||
|
|
||||||
|
Using different accelerators/backends
|
||||||
|
-----------------------------------------
|
||||||
|
|
||||||
|
Similar to nd4j-tensorflow which uses javacpp for the specific version of
|
||||||
|
tensorflow to use, this module will rely on the user picking the right dependency
|
||||||
|
to link against. Different builds of cpu, gpu, .. exist [here](https://repo1.maven.org/maven2/org/bytedeco/tensorflow/1.15.3-1.5.4/)
|
||||||
|
The equivalent of this in onnxruntime can be found [here](https://repo1.maven.org/maven2/org/bytedeco/onnxruntime/1.4.0-1.5.4/)
|
||||||
|
|
||||||
|
The user will need to include the version of onnxruntime they wish to use
|
||||||
|
similar to how you link against a particular implementation in a c library
|
||||||
|
or include a backend in nd4j. This will happen via maven.
|
||||||
|
|
|
@ -0,0 +1,251 @@
|
||||||
|
# Import IR
|
||||||
|
|
||||||
|
## Status
|
||||||
|
|
||||||
|
Proposed
|
||||||
|
|
||||||
|
Proposed by: Adam Gibson (28-09-2020)
|
||||||
|
|
||||||
|
Discussed with: Paul Dubs
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Currently, there is a gap in the way samediff/nd4j operations are implemented
|
||||||
|
vs. how other frameworks represent their models.
|
||||||
|
|
||||||
|
Keras, Tensorflow, and Pytorch use an attribute based format with names. Interop
|
||||||
|
between Onnx ,Tensorflow, and Keras tends to follow the following formula:
|
||||||
|
|
||||||
|
1. Map names to equivalent names in the other framework for each operation
|
||||||
|
configuration. Names being both op names and associated attributes of the
|
||||||
|
operations such as in Conv2D where you have strides, kernel sizes.
|
||||||
|
2. Map input/output tensors to the equivalent tensor type in each framework.
|
||||||
|
3. Setup the complete graph in the equivalent framework. Sometimes the
|
||||||
|
framework's concepts don't map 1 to 1. They should output equivalent results
|
||||||
|
regardless though. In order to do this, sometimes the framework needs to
|
||||||
|
add/remove operations in order to produce equivalent output in a different
|
||||||
|
graph. The [tensorflow onnx import](https://github.com/onnx/tensorflow-onnx#how-tf2onnx-works)
|
||||||
|
is a good example of this.
|
||||||
|
|
||||||
|
Samediff/nd4j have their internal op representations as a set of ordered
|
||||||
|
arguments for execution in the form of:
|
||||||
|
|
||||||
|
1. t arguments: floating point arguments (float, double,..)
|
||||||
|
2. integer arguments: integer arguments (long, integer)
|
||||||
|
3. boolean argument: boolean arguments
|
||||||
|
4. data type arguments: data types for input/output
|
||||||
|
5. input arguments: ndarrays for input
|
||||||
|
6. output arguments: often optional (dynamically created) output ndarray
|
||||||
|
arguments. If the user wants to pass in outputs to control memory, they are
|
||||||
|
allowed to do so.
|
||||||
|
7. axis arguments: Integer arguments that represent the dimension(s) for an
|
||||||
|
operation to be executed on.
|
||||||
|
|
||||||
|
[Reference implementation](https://github.com/KonduitAI/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java#L58)
|
||||||
|
|
||||||
|
This maps well enough for execution, but not for file formats.
|
||||||
|
|
||||||
|
## Related Work
|
||||||
|
This may encourage future work to be done to the
|
||||||
|
[samediff file format](https://github.com/KonduitAI/deeplearning4j/blob/master/nd4j/ADRs/0001-SameDiff_File_Format.md).
|
||||||
|
Implementation of serialization of file format via flatbuffers can be found
|
||||||
|
[here](https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java#L4748)
|
||||||
|
Of note here for prior work is the
|
||||||
|
[current code generation]
|
||||||
|
(https://github.com/KonduitAI/dl4j-dev-tools/blob/master/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt#L28)
|
||||||
|
|
||||||
|
The definitions for the kotlin dsl can be found
|
||||||
|
[here](https://github.com/KonduitAI/dl4j-dev-tools/blob/master/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt)
|
||||||
|
|
||||||
|
|
||||||
|
While it does have the intended description,
|
||||||
|
it’s kotlin specific and is only available for a very small subset
|
||||||
|
of the ops where pre-created objects were created
|
||||||
|
for specific operations. The goal of this ADR is to expand upon
|
||||||
|
that and make it language agnostic by providing this information in a
|
||||||
|
neutral file format that has code generation with it.
|
||||||
|
|
||||||
|
Current code generation efforts can be augmented using this file format.
|
||||||
|
More on this decision making can be found [here](https://github.com/KonduitAI/dl4j-dev-tools/blob/master/codegen/adr/0007-configuration_objects.md)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Proposal
|
||||||
|
|
||||||
|
We expose a symbol based mapping in libnd4j in protobuf format, similar to how
|
||||||
|
other frameworks are doing it, as a bridge/intermediary format.
|
||||||
|
|
||||||
|
This makes it easier to implement interop with the other frameworks, because it
|
||||||
|
adds the necessary information that is needed to be able to define a direct
|
||||||
|
mapping.
|
||||||
|
|
||||||
|
This could be a future file format depending on how the framework evolves. For
|
||||||
|
now, this is considered a work around for making writing import code easier/more
|
||||||
|
portable.
|
||||||
|
|
||||||
|
Similar to [ONNX](https://onnx.ai/) and [Tensorflow](https://tensorflow.org/)
|
||||||
|
we use protobuf to express an attribute based file format and map
|
||||||
|
samediff/nd4j operations to this format.
|
||||||
|
|
||||||
|
We use a translation layer that handles mapping from attributes to the ordered
|
||||||
|
arguments approach reflected in samediff/nd4j.
|
||||||
|
|
||||||
|
For each operation, we define a mapping process to/from this attribute format to the
|
||||||
|
order based execution format.
|
||||||
|
|
||||||
|
A separate but similar set of rules are used for mapping ndarrays.
|
||||||
|
|
||||||
|
This attribute based format is an Intermediary Representation that we then
|
||||||
|
"compile" to the equivalent calls in libnd4j.
|
||||||
|
|
||||||
|
|
||||||
|
The format definitions for the IR can be found [here](./src/main/proto/nd4j/nd4j.proto)
|
||||||
|
|
||||||
|
## Consequences
|
||||||
|
|
||||||
|
Migration to an attribute based import format makes working with other deep
|
||||||
|
learning frameworks easier in the future.
|
||||||
|
|
||||||
|
|
||||||
|
### Drawbacks
|
||||||
|
|
||||||
|
1. Yet another file format.
|
||||||
|
2. Risk migrating to new file format in the future.
|
||||||
|
3. A lot of up front manual work to index set of current operations.
|
||||||
|
4. Backwards compatibility: yet another thing to maintain. We wrote converters
|
||||||
|
for any forward compatibility. We address this by specifying an opset schema
|
||||||
|
scheme similar to onnx.
|
||||||
|
|
||||||
|
### Advantages
|
||||||
|
|
||||||
|
1. Easy to maintain.
|
||||||
|
2. Backwards compatible.
|
||||||
|
3. Easily interops with existing other deep learning frameworks.
|
||||||
|
4. No additional dependencies from what's already normal.
|
||||||
|
5. Protobuf allows easy code generation for other languages.
|
||||||
|
6. Industry standard conventions being used over proprietary tooling reducing
|
||||||
|
friction for adoption for people coming from other frameworks
|
||||||
|
7. Straightforward mapping of arguments for import
|
||||||
|
8. Provide an easy bridge to existing libnd4j
|
||||||
|
9. Allow automation of op descriptors in any language that would understand how
|
||||||
|
to pass data to the c++ library.
|
||||||
|
|
||||||
|
|
||||||
|
## Appendix A: Comparison with other Frameworks, implicit vs. explicit
|
||||||
|
|
||||||
|
We can find the existing attributes from the conventions of the
|
||||||
|
libnd4j code base. The libnd4j [conv1d.cpp](https://github.com/KonduitAI/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp#L104)
|
||||||
|
file contains the following declaration:
|
||||||
|
|
||||||
|
```
|
||||||
|
auto inputShapeInfo = inputShape->at(0);
|
||||||
|
auto weightsShapeInfo = inputShape->at(1);
|
||||||
|
Nd4jLong const* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr;
|
||||||
|
|
||||||
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(shape::sizeAt(weightsShapeInfo, 0)); // filter(kernel) width
|
||||||
|
int sW = INT_ARG(1); // strides width
|
||||||
|
int pW = INT_ARG(2); // paddings width
|
||||||
|
int dW = INT_ARG(3); // dilations width
|
||||||
|
int paddingMode = INT_ARG(4); // 0-VALID, 1-SAME
|
||||||
|
int isNCW = block.getIArguments()->size() > 5 ? !INT_ARG(5) : 1; // INT_ARG(4): 1-NWC, 0-NCW
|
||||||
|
int wFormat = block.getIArguments()->size() > 6 ? INT_ARG(6) : 0; // 0 - [kW, iC, oC], 1 - [oC, iC, kW], 2 - [oC, kW, iC]
|
||||||
|
```
|
||||||
|
|
||||||
|
We can see that there are macros in the libnd4j code base, which reflect how
|
||||||
|
each argument is accessed. Each list of arguments has an expected order, that we
|
||||||
|
need to explicitly map to a parseable structure.
|
||||||
|
|
||||||
|
In comparison, the
|
||||||
|
[onnx Convolution operator](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv)
|
||||||
|
has *explicit* attributes of various types such as lists of ints and named
|
||||||
|
tensors.
|
||||||
|
|
||||||
|
As shown above, these concepts exist internally in the operations and layers
|
||||||
|
themselves in nd4j/samediff, but they are not exposed directly to the user.
|
||||||
|
|
||||||
|
|
||||||
|
A theoretical op descriptor from libnd4j is as follows:
|
||||||
|
```java
|
||||||
|
private String name;
|
||||||
|
private int nIn,nOut,tArgs,iArgs;
|
||||||
|
private boolean inplaceAble;
|
||||||
|
private List<String> inArgNames;
|
||||||
|
private List<String> outArgNames;
|
||||||
|
private List<String> tArgNames;
|
||||||
|
private List<String> iArgNames;
|
||||||
|
private List<String> bArgNames;
|
||||||
|
private OpDeclarationType opDeclarationType;
|
||||||
|
|
||||||
|
public enum OpDeclarationType {
|
||||||
|
CUSTOM_OP_IMPL,
|
||||||
|
BOOLEAN_OP_IMPL,
|
||||||
|
LIST_OP_IMPL,
|
||||||
|
LOGIC_OP_IMPL,
|
||||||
|
OP_IMPL,
|
||||||
|
DIVERGENT_OP_IMPL,
|
||||||
|
CONFIGURABLE_OP_IMPL,
|
||||||
|
REDUCTION_OP_IMPL,
|
||||||
|
BROADCASTABLE_OP_IMPL,
|
||||||
|
BROADCASTABLE_BOOL_OP_IMPL
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
It contains all the op declarations and fields associated with a descriptor.
|
||||||
|
|
||||||
|
In the libnd4j code base, we represent the op descriptor types above
|
||||||
|
*implicitly* through validation as well as the different macros present in the
|
||||||
|
code base representing what an op execution looks like.
|
||||||
|
|
||||||
|
Validation for what can be present in the various names can be found
|
||||||
|
[here](https://github.com/KonduitAI/deeplearning4j/blob/master/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp#L734-L765)
|
||||||
|
|
||||||
|
The set of macro declarations in libnd4j can be found
|
||||||
|
[here](https://github.com/eclipse/deeplearning4j/blob/master/libnd4j/include/system/op_boilerplate.h)
|
||||||
|
|
||||||
|
|
||||||
|
## Appendix B: Format Comparison to other frameworks
|
||||||
|
|
||||||
|
An add op in tensorflow looks like:
|
||||||
|
|
||||||
|
```
|
||||||
|
op {
|
||||||
|
name: "Add"
|
||||||
|
input_arg {
|
||||||
|
name: "x"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "y"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "z"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_BFLOAT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Onnx’s add can be found here
|
||||||
|
https://github.com/onnx/onnx/blob/master/docs/Operators.md#Add
|
||||||
|
|
||||||
|
Onnx and tensorflow are purely attribute based formats.
|
|
@ -0,0 +1,196 @@
|
||||||
|
|
||||||
|
|
||||||
|
# Libnd4j NdArray padded buffers, strides for Arm_Compute Library wrapper
|
||||||
|
|
||||||
|
## Status
|
||||||
|
PROPOSED
|
||||||
|
|
||||||
|
Proposed by: Abdelrauf (23/09/2020)
|
||||||
|
|
||||||
|
Discussed with:
|
||||||
|
|
||||||
|
## Context
|
||||||
|
During the integration process of our library with arm_compute, I faced that our NdArray strides are not flexible. (i.e it cant be set properly without **special and manual handling**).
|
||||||
|
Let's say our Nd Array shapes are `[3,4,2]` and the last index is moving faster (i.e C order). Then our strides will be `[ 8, 2, 1 ]`.
|
||||||
|
As far as I know, our last index stride can be different (called as ews), but overall strides should follow the cyclic strict rule of dependency.:
|
||||||
|
|
||||||
|
strides[index-1] = strides[index] * shapes[index];
|
||||||
|
On arm_compute besides strides there is also Padding `{top, right, bottom, left}` that can be used to increase strides and change offsets adn as well as total size. its mostly done for performance reasons. As from above we can see that **its just hosting NdArray shape in the buffer of the bigger NdArray shape**. In arm_compute those paddings applied to last 2 dimensions (on NCHW it will be H and W}. We can define it like this:
|
||||||
|
|
||||||
|
newH = pad.top + H + pad.bottom;
|
||||||
|
newW = pad.left + W + pad.right;
|
||||||
|
|
||||||
|
so strides will be calculated for the shape `{N,C, newH, newW}` and offset of the first element will be:
|
||||||
|
|
||||||
|
offset = pad.left * strideOfNewW + pad.top * strideOfNewH
|
||||||
|
|
||||||
|
|
||||||
|
## Proposal
|
||||||
|
Introduce helper functions checking below case :
|
||||||
|
|
||||||
|
strides[index-1] >= strides[index] * shapes[index];
|
||||||
|
|
||||||
|
Add **generic method for the padded buffer** ( we can simulate arm_compute 2d padding and more)
|
||||||
|
|
||||||
|
int paddings[rank] = {...}; // total padding
|
||||||
|
int paddingOffsets[rank] = {...}; //offset indices of the first element
|
||||||
|
|
||||||
|
This could be used to padd ndArray shapes and calculate strides based on it while keeping original shape, paddOffsets could be used to determine the beginning of the first element. Though this interface ismore generic its drawback is that on armcompute its possible to padd 1d into 2D while keeping rank but on this one we should supply 2d with one of its dimensions being 1.
|
||||||
|
|
||||||
|
|
||||||
|
## Consequences
|
||||||
|
|
||||||
|
1. All tests that were not tested **against subArray** could break. So they will require a fix
|
||||||
|
2. Writing additional test cases
|
||||||
|
|
||||||
|
### Advantages
|
||||||
|
- alignment possibility for CPUs where alignment is required for speed and vectorization.
|
||||||
|
- easier integration with libraries. in the case of arm_compute, the last two dimensions are sometimes padded.
|
||||||
|
|
||||||
|
|
||||||
|
### Disadvantages
|
||||||
|
- its advantage is not so big for modern CPUs where unaligned vector loads possible
|
||||||
|
- exposing it for users is not desirable: (excessive usage creates unnecessary memory spaces and performance problems)
|
||||||
|
- could result in unnecessary complications for some function implementations
|
||||||
|
- possibility of requiring additional tests and fixes
|
||||||
|
|
||||||
|
|
||||||
|
### Technical details about the addition of this functionality into NdArray
|
||||||
|
A little investigation showed that the current NdArray actually has constructors to specify strides.
|
||||||
|
Here is the constructor that could be used
|
||||||
|
[ShapeDescriptor.h](https://github.com/KonduitAI/deeplearning4j/blob/qwr_armcompute/libnd4j/include/array/ShapeDescriptor.h)
|
||||||
|
Here are additions into ShapeDescriptor:
|
||||||
|
- validate() //it willbe used for validation of strides and et cetera. This way we can create NdArray by just using ShapeDescriptor alone. And it will be more flexible with correctness
|
||||||
|
- allocLength() //returns minimal buffer size for the given strides and shapes. (this was missing on libnd4j side)
|
||||||
|
- paddedBufferDescriptor(..) //helper method for returning ShapeDescriptor for padded buffer.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#### [NdArrayFactory](https://github.com/KonduitAI/deeplearning4j/blob/qwr_armcompute/libnd4j/include/array/impl/NDArrayFactory.cpp#L39-L80)
|
||||||
|
The method that is using ShapeDescriptor validation, and ShapeDescriptor paddedBuffer .
|
||||||
|
|
||||||
|
Furthermore to indicate that shape of the NdArray is using paddedBuffer we will flag with `ARRAY_HAS_PADDED_BUFFER` . so it will be possible to know if NdArray is padded.
|
||||||
|
|
||||||
|
Furthermore, it is still possible to recover Paddings from the allocation size of the padded NdArray. But its not an easy task to get PaddingOffsets from offset and recovered full shape. Thats why it requires storing them. Fortunately, for arm_compute tensors **manual padding** we just need to know **total size and the offset** of the first element. So we dont need to change internals that much
|
||||||
|
|
||||||
|
As our padded Buffer follows the strict ews() rule instead of the loose one. Paddings will be obtained from this rule:
|
||||||
|
|
||||||
|
strides[index-1] == strides[index] * shapes[index];
|
||||||
|
|
||||||
|
pseudo code for C order:
|
||||||
|
|
||||||
|
for (int j = rank - 1; j >= 0; j--) {
|
||||||
|
shapesAfterPadding[j] = strides[j - 1] / strides[j]
|
||||||
|
}
|
||||||
|
shapesAfterPadding[0] = buffer.AllocSize / strides[0]
|
||||||
|
//Paddings for index in 0..rank-1
|
||||||
|
paddings[index] = shapesAfterPadding[index] - shape[index]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Technical notes on arm_compute library
|
||||||
|
|
||||||
|
The main drive for the above proposal to avoid unnecessary performance and memory allocation. And also we should keep on mind :
|
||||||
|
- in each newer version of arm_compute there are new implementations in which the padding requirements were removed.
|
||||||
|
|
||||||
|
This **can diminish the necessity for the proposed changes** if such versions of the desired functions are implemented.
|
||||||
|
|
||||||
|
##### Notes on arm_compute tensors
|
||||||
|
Arm_compute tensors are mostly 3d 4d with max 6d dimensions.
|
||||||
|
So lets show C order NdArray({2,2,5,5},)
|
||||||
|
|
||||||
|
shapeInfo shapeInfo: [4, 2,2,5,5, 50,25,5,1, 8192,1,99]
|
||||||
|
|
||||||
|
of float type and its arm_compute tensor equivalent :
|
||||||
|
- first of all, we map NdArray dataTypes into arm_compute [armcomputeUtils.cpp#L35-L75](https://github.com/KonduitAI/deeplearning4j/blob/qwr_armcompute/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp#L35-L75)
|
||||||
|
- it will be with the reversed shape. **`NdArray{n,z,y,x} -> TensorShape{x,y,z,n}`**
|
||||||
|
-
|
||||||
|
|
||||||
|
total length in bytes: 400
|
||||||
|
shapes: 5,5,2,2,1,1,
|
||||||
|
strides in bytes: 4,20,100,200,0,0,
|
||||||
|
strides as elements: (1,5,25,50)
|
||||||
|
|
||||||
|
Paddings in arm_compute Tensors. `Padding{left,right, top, bottom}`
|
||||||
|
As both OpenCL and NEON use vector loads and stores instructions to access the data in buffers, so in order to avoid having special cases to handle for the borders all the images and tensors used in this library must be padded
|
||||||
|
There are different ways padding can be calculated:
|
||||||
|
|
||||||
|
- Accurate padding.
|
||||||
|
in this case it is importan to configure and then after that to allocate
|
||||||
|
- auto padding.
|
||||||
|
It guarantees that the allocation will have enough padding to run any of the provided functions
|
||||||
|
- no padding
|
||||||
|
- manual padding
|
||||||
|
|
||||||
|
#### how padding affects strides offset and total size
|
||||||
|
in arm_compute Tensor:
|
||||||
|
it's 2d {Width Height} can be padded and thats why it affects strides.
|
||||||
|
Lets show it with the picture:
|
||||||
|
|
||||||
|
\ top /
|
||||||
|
\ _____________________ /
|
||||||
|
left | ^ | right
|
||||||
|
| Width |
|
||||||
|
| <-Height |
|
||||||
|
| |
|
||||||
|
| |
|
||||||
|
----------------------
|
||||||
|
/ bottom \
|
||||||
|
/ \
|
||||||
|
|
||||||
|
Here is the stride calculation pseudo code for Tensor {x,y,z}
|
||||||
|
|
||||||
|
stride_x = element_size(); //float will be 4
|
||||||
|
stride_y = (padding.left + _tensor_shape[0] + padding.right) * stride_x;
|
||||||
|
stride_z = (padding.top + _tensor_shape[1] + padding.bottom) * stride_y;
|
||||||
|
|
||||||
|
required_offset_first_element = padding.left * stride_x + padding.top * stride_y;
|
||||||
|
|
||||||
|
|
||||||
|
For example: if arm_tensor had `padding: left 0, right 1, top 0, bottom 1` :
|
||||||
|
|
||||||
|
total: 576
|
||||||
|
shapes: 5,5,2,2,1,1,
|
||||||
|
strides in bytes: 4,24,144,288,0,0,
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Notes on the current wrapper implementation
|
||||||
|
|
||||||
|
This is a simple wrapper for arm functions with input and output tensors:
|
||||||
|
[armcomputeUtils.h#L95-L165](https://github.com/KonduitAI/deeplearning4j/blob/qwr_armcompute/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h#L85-L133)
|
||||||
|
|
||||||
|
From above we could see :
|
||||||
|
- we had to flag padded NdArrays so that we can use manual padding version of arm_compute Tensors
|
||||||
|
- when padding information is changed during configure process we **have to copy** our NdArray buffer into **new allocated** arm_tensor buffer. and the same with the output.
|
||||||
|
- for cases without padding , arm_tensor could use our buffer if its ews()==1.
|
||||||
|
- its desired to call configure and run separately to avoid multiple configure calls ( this is not discussed here, for now)
|
||||||
|
|
||||||
|
|
||||||
|
## arm_compute wrapper proposal
|
||||||
|
|
||||||
|
|
||||||
|
So from above we can conclude that we have two options:
|
||||||
|
|
||||||
|
- creating our NdArray with auto_padding strides and modifying the current wrapper. Still configure will be called foreach run. But with auto padding it is using more memory for small ndarrays
|
||||||
|
- to be able to use accurate padding properly we should call configure before NdArray memory allocation so that we can import it. For that I should investigate graph, DeclarableOps and NdArrays usage lifecycle.
|
||||||
|
|
||||||
|
Here is auto padding:
|
||||||
|
|
||||||
|
// Some kernels compute 32 elements at the time, worst case scenario they
|
||||||
|
// will read 32 values after the last element
|
||||||
|
extra_pad_x = _tensor_shape.num_dimensions() < 1 ? 0 : 32;
|
||||||
|
pad_x = _tensor_shape.num_dimensions() < 1 ? 0 : 4;
|
||||||
|
pad_y = _tensor_shape.num_dimensions() < 2 ? 0 : 4;
|
||||||
|
|
||||||
|
PaddingSize(pad_y, pad_x + extra_pad_x, pad_y, pad_x);
|
||||||
|
|
||||||
|
## Discussion
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,275 @@
|
||||||
|
# Import IR
|
||||||
|
|
||||||
|
## Status
|
||||||
|
Proposed
|
||||||
|
|
||||||
|
Proposed by: Adam Gibson (28-09-2020)
|
||||||
|
|
||||||
|
Discussed with: N/A
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Generally, every neural network file format defines a sequence of operations
|
||||||
|
to execute mathematical operations that comprises a neural network.
|
||||||
|
|
||||||
|
Each element in the sequence is a node that contains information such as the
|
||||||
|
desired operation, and a set of attributes that represent parameters
|
||||||
|
in to the mathematical function to execute.
|
||||||
|
|
||||||
|
In order to write import/export for different frameworks, we need to adapt
|
||||||
|
an attribute based format from various popular deep learning frameworks.
|
||||||
|
Nd4j has a different list based format for operation execution arguments.
|
||||||
|
In the [previous ADR](./Import_IR.md), we added an IR which makes it easier to
|
||||||
|
interop with other frameworks.
|
||||||
|
|
||||||
|
In this ADR, this work is extended to add a file format for
|
||||||
|
describing lists of operations as MappingRules which allow transformations
|
||||||
|
from one framework to another.
|
||||||
|
|
||||||
|
These transformations manipulate protobuf as input and output Nd4j's
|
||||||
|
new OpDescriptor format as output.
|
||||||
|
|
||||||
|
|
||||||
|
##Related work
|
||||||
|
|
||||||
|
See [the import IR](./0003-Import_IR.md)
|
||||||
|
|
||||||
|
## Decision
|
||||||
|
|
||||||
|
We implement a mapping process framework that defines transforms on an input file format.
|
||||||
|
A MappingProcess defines a list of MappingRules which represent a sequence of transformations
|
||||||
|
on each attribute of an op definition.
|
||||||
|
|
||||||
|
To assist in mapping, a mapping context with needed information like rule arguments
|
||||||
|
for transformation, current node, and whole graph are used as input.
|
||||||
|
|
||||||
|
The input is a protobuf file for a specific framework and the output is an op descriptor
|
||||||
|
described [here](./0003-Import_IR.md).
|
||||||
|
|
||||||
|
A MappingRule converts 1 or more attributes in to 1 more or arg definitions. A potential definition
|
||||||
|
can be found in Appendix E.
|
||||||
|
|
||||||
|
Attributes are named values supporting a wide variety of types from floats/doubles
|
||||||
|
to lists of the same primitive types. See Appendix C for a theoretical definition.
|
||||||
|
|
||||||
|
Arg Definitions are the arguments for an OpDescriptor described in [the import IR ADR.](./0003-Import_IR.md)
|
||||||
|
See Appendix D for a potential definition of arg definitions.
|
||||||
|
|
||||||
|
All of this together describes how to implement a framework agnostic
|
||||||
|
interface to convert between a target deep learning framework and the nd4j format.
|
||||||
|
|
||||||
|
|
||||||
|
## Implementation details
|
||||||
|
|
||||||
|
In order to implement proper mapping functionality, a common interface is implemented.
|
||||||
|
Below are the needed common types for mapping:
|
||||||
|
|
||||||
|
1. IRNodeDef: A node definition in a graph
|
||||||
|
2. IRTensor: A tensor type for mapping
|
||||||
|
3. IROpList: A list of operations
|
||||||
|
4. IRAttrDef: An attribute definition
|
||||||
|
5. IRAttrValue: An attribute value
|
||||||
|
6. IROpDef: An op definition for the IR
|
||||||
|
7. IRDataType: A data type
|
||||||
|
8. IRGraph: A graph abstraction
|
||||||
|
|
||||||
|
Each one of these types is a wrapper around a specific framework's input types
|
||||||
|
of the equivalent concepts.
|
||||||
|
|
||||||
|
Each of these wrappers knows how to convert the specific concepts
|
||||||
|
in to the nd4j equivalents for interpretation by a mapper which applies
|
||||||
|
the mapping rules for a particular framework.
|
||||||
|
|
||||||
|
Doing this will allow us to share logic between mappers and making 1 implementation of
|
||||||
|
mapping possible by calling associated getter methods for concepts like data types and nodes.
|
||||||
|
|
||||||
|
## Serialization
|
||||||
|
|
||||||
|
In order to persist rules using protobuf, all rules will know how to serialize themselves.
|
||||||
|
A simple serialize() and load() methods are implemented which covers conversion using
|
||||||
|
interface methods up to the user to implement which describes how to persist the protobuf
|
||||||
|
representation. This applies to any of the relevant functionality such as rules and processes.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Custom types
|
||||||
|
|
||||||
|
Some types will not map 1 to 1 or are directly applicable to nd4j.
|
||||||
|
In order to combat this, when an unknown type is discovered during mapping,
|
||||||
|
adapter functions for specific types must be specified.
|
||||||
|
|
||||||
|
Supported types include:
|
||||||
|
|
||||||
|
1. Long/Int
|
||||||
|
2. Double/Float
|
||||||
|
3. String
|
||||||
|
4. Boolean
|
||||||
|
5. Bytes
|
||||||
|
6. NDArrays
|
||||||
|
|
||||||
|
|
||||||
|
An example:
|
||||||
|
|
||||||
|
A Dim in tensorflow can be mapped to a long in nd4j.
|
||||||
|
|
||||||
|
Shape Information can be a list of longs or multiple lists depending on the
|
||||||
|
context.
|
||||||
|
|
||||||
|
## Consequences
|
||||||
|
### Advantages
|
||||||
|
* Allows a language neutral way of describing a set of transforms necessary
|
||||||
|
for mapping an set of operations found in a graph from one framework to the nd4j format.
|
||||||
|
|
||||||
|
* Allows a straightforward way of writing an interpreter as well as mappers
|
||||||
|
for different frameworks in nd4j in a standardized way.
|
||||||
|
|
||||||
|
* Replaces the old import and makes maintenance of imports/mappers more straightforward.
|
||||||
|
|
||||||
|
### Disadvantages
|
||||||
|
|
||||||
|
* More complexity in the code base instead of a more straightforward java implementation.
|
||||||
|
|
||||||
|
* Risks introducing new errors due to a rewrite
|
||||||
|
|
||||||
|
|
||||||
|
## Appendix A: Contrasting MappingRules with another implementation
|
||||||
|
|
||||||
|
We map names and types to equivalent concepts in each framework.
|
||||||
|
Onnx tensorflow does this with an [attribute converter](https://github.com/onnx/onnx-tensorflow/blob/08e41de7b127a53d072a54730e4784fe50f8c7c3/onnx_tf/common/attr_converter.py)
|
||||||
|
|
||||||
|
This is done by a handler (one for each op).
|
||||||
|
More can be found [here](https://github.com/onnx/onnx-tensorflow/tree/master/onnx_tf/handlers/backend)
|
||||||
|
|
||||||
|
|
||||||
|
## Appendix B: Challenges when mapping nd4j ops
|
||||||
|
|
||||||
|
The above formats are vastly different. Onnx and tensorflow
|
||||||
|
are purely attribute based. Nd4j is index based.
|
||||||
|
This challenge is addressed by the IR by adding names to each property.
|
||||||
|
|
||||||
|
|
||||||
|
In order to actually map these properties, we need to define rules for doing so.
|
||||||
|
Examples of why these mapping rules are needed below:
|
||||||
|
|
||||||
|
1. Different conventions for the same concept. One example that stands out from conv
|
||||||
|
is padding. Padding can be represented as a string or have a boolean that says what a string equals.
|
||||||
|
In nd4j, we represent this as a boolean: isSameMode. We need to do a conversion inline in order
|
||||||
|
to invoke nd4j correctly.
|
||||||
|
|
||||||
|
2. Another issue is implicit concepts. Commonly, convolution requires you to configure a layout
|
||||||
|
of NWHC (Batch size, Height, Width, Channels)
|
||||||
|
or NCHW (Batch size, Channels,Height, Width). Tensorflow allows you to specify it,
|
||||||
|
nd4j also allows you to specify it. Onnx does not.
|
||||||
|
|
||||||
|
A more in depth conversation on this specific issue relating to the
|
||||||
|
2 frameworks can be found [here](https://github.com/onnx/onnx-tensorflow/issues/31)
|
||||||
|
In order to address these challenges, we introduce a MappingRule allowing
|
||||||
|
us to define a series of steps to map the input format to the nd4j format
|
||||||
|
in a language neutral way via a protobuf declaration.
|
||||||
|
|
||||||
|
|
||||||
|
## Appendix C: A theoretical attribute definition
|
||||||
|
```kotlin
|
||||||
|
enum class AttributeValueType {
|
||||||
|
FLOAT,
|
||||||
|
LIST_FLOAT,
|
||||||
|
BYTE,
|
||||||
|
LIST_BYTE,
|
||||||
|
INT,
|
||||||
|
LIST_INT,
|
||||||
|
BOOL,
|
||||||
|
LIST_BOOL,
|
||||||
|
STRING,
|
||||||
|
LIST_STRING
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IRAttribute<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE> {
|
||||||
|
|
||||||
|
fun name(): String
|
||||||
|
|
||||||
|
fun floatValue(): Double
|
||||||
|
|
||||||
|
fun listFloatValue(): List<Float>
|
||||||
|
|
||||||
|
fun byteValue(): Byte
|
||||||
|
|
||||||
|
fun listByteValue(): List<Byte>
|
||||||
|
|
||||||
|
fun intValue(): Long
|
||||||
|
|
||||||
|
fun listIntValue(): List<Long>
|
||||||
|
|
||||||
|
fun boolValue(): Boolean
|
||||||
|
|
||||||
|
fun listBoolValue(): List<Boolean>
|
||||||
|
|
||||||
|
fun attributeValueType(): AttributeValueType
|
||||||
|
|
||||||
|
fun internalAttributeDef(): ATTRIBUTE_TYPE
|
||||||
|
|
||||||
|
fun internalAttributeValue(): ATTRIBUTE_VALUE_TYPE
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Appendix D: A theoretical kotlin definition of argument descriptors and op descriptors can be found below:
|
||||||
|
```kotlin
|
||||||
|
interface IRArgDef<T,DATA_TYPE> {
|
||||||
|
fun name(): String
|
||||||
|
|
||||||
|
fun description(): String
|
||||||
|
|
||||||
|
fun dataType(): IRDataType<DATA_TYPE>
|
||||||
|
|
||||||
|
fun internalValue(): T
|
||||||
|
|
||||||
|
fun indexOf(): Integer
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IROpDef<T,ARG_DEF_TYPE,DATA_TYPE,ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE> {
|
||||||
|
fun opName(): String
|
||||||
|
|
||||||
|
fun internalValue(): T
|
||||||
|
|
||||||
|
fun inputArgs(): List<IRArgDef<ARG_DEF_TYPE,DATA_TYPE>>
|
||||||
|
|
||||||
|
fun outputArgs(): List<IRArgDef<ARG_DEF_TYPE,DATA_TYPE>>
|
||||||
|
|
||||||
|
fun attributes(): List<IRAttribute<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE>>
|
||||||
|
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
##Appendix E: A theoretical kotlin definition of Mapping Rules, MappingProcess and ArgDef can be found below:
|
||||||
|
```kotlin
|
||||||
|
interface MappingProcess<T,TENSOR_TYPE,ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE,DATA_TYPE> {
|
||||||
|
fun opName(): String
|
||||||
|
|
||||||
|
fun frameworkVersion(): String
|
||||||
|
|
||||||
|
fun inputFramework(): String
|
||||||
|
|
||||||
|
fun rules(): List<MappingRule<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE>>
|
||||||
|
|
||||||
|
|
||||||
|
fun applyProcess(inputNode: IRNode<T,TENSOR_TYPE,ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE,DATA_TYPE>): OpDeclarationDescriptor
|
||||||
|
|
||||||
|
fun applyProcessReverse(input: OpDeclarationDescriptor): IRNode<T,TENSOR_TYPE,ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE,DATA_TYPE>
|
||||||
|
|
||||||
|
fun createDescriptor(argDescriptors: List<OpNamespace.ArgDescriptor>): OpDeclarationDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MappingRule<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE> {
|
||||||
|
fun name(): String
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert 1 or more attributes in to a list of {@link ArgDescriptor}
|
||||||
|
*/
|
||||||
|
fun convert(inputs: List<IRAttribute<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE>> ): List<OpNamespace.ArgDescriptor>
|
||||||
|
|
||||||
|
fun convertReverse(input: List<OpNamespace.ArgDescriptor>): List<IRAttribute<ATTRIBUTE_TYPE,ATTRIBUTE_VALUE_TYPE>>
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Interpreter
|
||||||
|
|
||||||
|
## Status
|
||||||
|
Proposed
|
||||||
|
|
||||||
|
Proposed by: Adam Gibson (28-09-2020)
|
||||||
|
|
||||||
|
Discussed with: N/A
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
|
||||||
|
## Decision
|
||||||
|
|
||||||
|
An interpreter uses the [import IR](./0003-Import_IR.md) and the [mapping rule IR](./0004-Mapping_IR.md)
|
||||||
|
to execute and map operations from one framework to nd4j's file format and back.
|
||||||
|
|
||||||
|
This also allows execution of different frameworks via conversion in the nd4j engine.
|
||||||
|
|
||||||
|
|
||||||
|
A combination of the 2 allows a uniform interface to be used for the interpreter.
|
||||||
|
|
||||||
|
1 or more MappingRules will be used to transform 1 file format to another.
|
||||||
|
|
||||||
|
|
||||||
|
## Mapping Rules Execution
|
||||||
|
|
||||||
|
Mapping Rules are named functions that contain the function signature
|
||||||
|
(input and outputs). These mapping rules are used by the interpreter
|
||||||
|
to know which functions to execute.
|
||||||
|
|
||||||
|
The interpreter has built in implementations of the defined functions
|
||||||
|
for the desired transforms.
|
||||||
|
|
||||||
|
|
||||||
|
## Import process
|
||||||
|
|
||||||
|
An import process is defined for an overall framework.
|
||||||
|
It maps input graphs to samediff graphs using
|
||||||
|
specified mapping processes for op names and frameworks.
|
||||||
|
An import process is all that is needed to create a graph.
|
||||||
|
Below are the needed concepts for an import process to implement.
|
||||||
|
|
||||||
|
|
||||||
|
## Graph creation
|
||||||
|
|
||||||
|
In order for execution to happen, a graph needs to be built.
|
||||||
|
This happens in java via the samediff builder.
|
||||||
|
|
||||||
|
The conversion happens as follows:
|
||||||
|
input node -> convert node to op descriptor via defined mapping rules -> add op descriptor to graph
|
||||||
|
|
||||||
|
The op descriptor is converted to a CustomOp which is then added to the graph via
|
||||||
|
[addArgsFor](https://github.com/KonduitAI/deeplearning4j/blob/88d3c4867fb87ec760b445c6b9459ecf353cec47/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java#L1078)
|
||||||
|
|
||||||
|
This handles declarative graph creation setting dependencies up. Delegation of the graph structure
|
||||||
|
creation to the existing Samediff library enables the scope of this interpreter to be focused on
|
||||||
|
mapping operations.
|
||||||
|
|
||||||
|
## Custom Sub graphs
|
||||||
|
|
||||||
|
One common use case is mapping sub graphs to custom layers. A custom layer can be thought of as a sequence of operations.
|
||||||
|
In order to map this, a named process can be created. Generally, if you know what ops the sub graph is made of,
|
||||||
|
you only need to declare a set of rules based on the rules that map individual ops in the existing framework.
|
||||||
|
|
||||||
|
## Consequences
|
||||||
|
### Advantages
|
||||||
|
* Uses a common interface across different frameworks making maintenance simple
|
||||||
|
|
||||||
|
* Allows an easy to maintain abstraction for interop with different file formats
|
||||||
|
|
||||||
|
* Allows an easy entry point in to the framework without knowing much about the framework.
|
||||||
|
|
||||||
|
### Disadvantages
|
||||||
|
|
||||||
|
* Need to ensure compatibility across different frameworks
|
||||||
|
|
||||||
|
* Requires extensive testing to ensure proper compatibility
|
||||||
|
|
||||||
|
* May not necessarily support all ops people are expecting. This will be addressed
|
||||||
|
in a new ADR.
|
|
@ -1,12 +1,13 @@
|
||||||
#!groovy
|
|
||||||
|
|
||||||
/* ******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
*
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
|
* See the NOTICE file distributed with this work for additional
|
||||||
|
* information regarding copyright ownership.
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
@ -16,6 +17,8 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
|
#!groovy
|
||||||
|
|
||||||
/*
|
/*
|
||||||
To redefine some job/run parameters,
|
To redefine some job/run parameters,
|
||||||
please provide arguments to jenkinsBuilder step.
|
please provide arguments to jenkinsBuilder step.
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
Eclipse Deeplearning4j
|
||||||
|
Copyright 2021 Eclipse Deeplearning4j Contributors
|
||||||
|
|
||||||
|
This product includes software developed at
|
||||||
|
The Apache Software Foundation (http://www.apache.org/).
|
||||||
|
|
||||||
|
This product includes software developed at
|
||||||
|
* Skymind Inc (Apache 2.0). Copyright (C) 2015-2018 Skymind Inc .
|
||||||
|
|
||||||
|
This product includes software developed at
|
||||||
|
* Konduit KK (Apache 2.0). Copyright (C) 2020.
|
||||||
|
|
||||||
|
|
||||||
|
This product includes software from the Tensorflow Project (Apache 2.0).
|
||||||
|
* Copyright (C) 2015-2018 Tensorflow Authors.
|
||||||
|
|
||||||
|
# https://github.com/onnx/onnx
|
||||||
|
|
||||||
|
This product includes software from the Onnx Project project (Apache 2.0).
|
||||||
|
* Copyright (C) 2020 Onnx Contributors (https://github.com/onnx/onnx)
|
|
@ -1,45 +0,0 @@
|
||||||
# Arbiter
|
|
||||||
|
|
||||||
A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise.
|
|
||||||
|
|
||||||
|
|
||||||
## Modules
|
|
||||||
Arbiter contains the following modules:
|
|
||||||
|
|
||||||
- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI
|
|
||||||
- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks)
|
|
||||||
|
|
||||||
|
|
||||||
## Hyperparameter Optimization Functionality
|
|
||||||
|
|
||||||
The open-source version of Arbiter currently defines two methods of hyperparameter optimization:
|
|
||||||
|
|
||||||
- Grid search
|
|
||||||
- Random search
|
|
||||||
|
|
||||||
For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes
|
|
||||||
For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf).
|
|
||||||
|
|
||||||
### Core Concepts and Classes in Arbiter for Hyperparameter Optimization
|
|
||||||
|
|
||||||
In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following:
|
|
||||||
|
|
||||||
- **Parameter Space**: A ```ParameterSpace<P>``` specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
|
|
||||||
- **Candidate Generator**: A ```CandidateGenerator<C>``` is used to generate candidate models configurations of some type ```C```. The following implementations are defined in arbiter-core:
|
|
||||||
- ```RandomSearchCandidateGenerator```
|
|
||||||
- ```GridSearchCandidateGenerator```
|
|
||||||
- **Score Function**: A ```ScoreFunction<M,D>``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
|
|
||||||
- A key concept here is that they score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
|
|
||||||
- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions:
|
|
||||||
- ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed
|
|
||||||
- ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
|
|
||||||
- **Result Saver**: The ```ResultSaver<C,M,A>``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
|
|
||||||
- Note that ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
|
|
||||||
- **Optimization Configuration**: An ```OptimizationConfiguration<C,M,D,A>``` ties together the above configuration options in a fluent (builder) pattern.
|
|
||||||
- **Candidate Executor**: The ```CandidateExecutor<C,M,D,A>``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
|
|
||||||
- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
|
|
||||||
|
|
||||||
|
|
||||||
### Optimization of DeepLearning4J Models
|
|
||||||
|
|
||||||
(This section: forthcoming)
|
|
|
@ -1,105 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>arbiter</artifactId>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>arbiter-core</artifactId>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<name>arbiter-core</name>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>com.google.code.findbugs</groupId>
|
|
||||||
<artifactId>*</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.commons</groupId>
|
|
||||||
<artifactId>commons-lang3</artifactId>
|
|
||||||
<version>${commons.lang.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.commons</groupId>
|
|
||||||
<artifactId>commons-math3</artifactId>
|
|
||||||
<version>${commons.math.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>junit</groupId>
|
|
||||||
<artifactId>junit</artifactId>
|
|
||||||
<version>${junit.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-api</artifactId>
|
|
||||||
<version>${slf4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>ch.qos.logback</groupId>
|
|
||||||
<artifactId>logback-classic</artifactId>
|
|
||||||
<version>${logback.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>joda-time</groupId>
|
|
||||||
<artifactId>joda-time</artifactId>
|
|
||||||
<version>${jodatime.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<!-- ND4J Shaded Jackson Dependency -->
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>jackson</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-11.0</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,91 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<assembly>
|
|
||||||
<id>bin</id>
|
|
||||||
<!-- START SNIPPET: formats -->
|
|
||||||
<formats>
|
|
||||||
<format>tar.gz</format>
|
|
||||||
<!--
|
|
||||||
<format>tar.bz2</format>
|
|
||||||
<format>zip</format>
|
|
||||||
-->
|
|
||||||
</formats>
|
|
||||||
<!-- END SNIPPET: formats -->
|
|
||||||
|
|
||||||
<dependencySets>
|
|
||||||
<dependencySet>
|
|
||||||
<outputDirectory>lib</outputDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*:jar:*</include>
|
|
||||||
</includes>
|
|
||||||
<excludes>
|
|
||||||
<exclude>*:sources</exclude>
|
|
||||||
</excludes>
|
|
||||||
</dependencySet>
|
|
||||||
</dependencySets>
|
|
||||||
|
|
||||||
<!-- START SNIPPET: fileSets -->
|
|
||||||
<fileSets>
|
|
||||||
<fileSet>
|
|
||||||
<includes>
|
|
||||||
<include>readme.txt</include>
|
|
||||||
</includes>
|
|
||||||
</fileSet>
|
|
||||||
|
|
||||||
<fileSet>
|
|
||||||
<directory>src/main/resources/bin/</directory>
|
|
||||||
<outputDirectory>bin</outputDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>arbiter</include>
|
|
||||||
</includes>
|
|
||||||
<lineEnding>unix</lineEnding>
|
|
||||||
<fileMode>0755</fileMode>
|
|
||||||
</fileSet>
|
|
||||||
|
|
||||||
<fileSet>
|
|
||||||
<directory>examples</directory>
|
|
||||||
<outputDirectory>examples</outputDirectory>
|
|
||||||
<!--
|
|
||||||
<lineEnding>unix</lineEnding>
|
|
||||||
https://stackoverflow.com/questions/2958282/stranges-files-in-my-assembly-since-switching-to-lineendingunix-lineending
|
|
||||||
-->
|
|
||||||
</fileSet>
|
|
||||||
|
|
||||||
|
|
||||||
<!--
|
|
||||||
<fileSet>
|
|
||||||
<directory>src/bin</directory>
|
|
||||||
<outputDirectory>bin</outputDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>hello</include>
|
|
||||||
</includes>
|
|
||||||
<lineEnding>unix</lineEnding>
|
|
||||||
<fileMode>0755</fileMode>
|
|
||||||
</fileSet>
|
|
||||||
-->
|
|
||||||
|
|
||||||
<fileSet>
|
|
||||||
<directory>target</directory>
|
|
||||||
<outputDirectory>./</outputDirectory>
|
|
||||||
<includes>
|
|
||||||
<include>*.jar</include>
|
|
||||||
</includes>
|
|
||||||
</fileSet>
|
|
||||||
|
|
||||||
</fileSets>
|
|
||||||
<!-- END SNIPPET: fileSets -->
|
|
||||||
</assembly>
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 23/07/2017.
|
|
||||||
*/
|
|
||||||
public abstract class AbstractParameterSpace<T> implements ParameterSpace<T> {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
Map<String, ParameterSpace> m = new LinkedHashMap<>();
|
|
||||||
|
|
||||||
//Need to manually build and walk the class hierarchy...
|
|
||||||
Class<?> currClass = this.getClass();
|
|
||||||
List<Class<?>> classHierarchy = new ArrayList<>();
|
|
||||||
while (currClass != Object.class) {
|
|
||||||
classHierarchy.add(currClass);
|
|
||||||
currClass = currClass.getSuperclass();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = classHierarchy.size() - 1; i >= 0; i--) {
|
|
||||||
//Use reflection here to avoid a mass of boilerplate code...
|
|
||||||
Field[] allFields = classHierarchy.get(i).getDeclaredFields();
|
|
||||||
|
|
||||||
for (Field f : allFields) {
|
|
||||||
|
|
||||||
String name = f.getName();
|
|
||||||
Class<?> fieldClass = f.getType();
|
|
||||||
boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass);
|
|
||||||
|
|
||||||
if (!isParamSpacefield) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
f.setAccessible(true);
|
|
||||||
|
|
||||||
ParameterSpace<?> p;
|
|
||||||
try {
|
|
||||||
p = (ParameterSpace<?>) f.get(this);
|
|
||||||
} catch (IllegalAccessException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p != null) {
|
|
||||||
m.put(name, p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,57 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier;
|
|
||||||
import org.nd4j.common.function.Supplier;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Candidate: a proposed hyperparameter configuration.
|
|
||||||
* Also includes a map for data parameters, to configure things like data preprocessing, etc.
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class Candidate<C> implements Serializable {
|
|
||||||
|
|
||||||
private Supplier<C> supplier;
|
|
||||||
private int index;
|
|
||||||
private double[] flatParameters;
|
|
||||||
private Map<String, Object> dataParameters;
|
|
||||||
private Exception exception;
|
|
||||||
|
|
||||||
public Candidate(C value, int index, double[] flatParameters, Map<String,Object> dataParameters, Exception e) {
|
|
||||||
this(new SerializedSupplier<C>(value), index, flatParameters, dataParameters, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Candidate(C value, int index, double[] flatParameters) {
|
|
||||||
this(new SerializedSupplier<C>(value), index, flatParameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Candidate(Supplier<C> value, int index, double[] flatParameters) {
|
|
||||||
this(value, index, flatParameters, null, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public C getValue(){
|
|
||||||
return supplier.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation.
|
|
||||||
* This abstraction allows for different ways of generating the next configuration to test; for example,
|
|
||||||
* random search, grid search, Bayesian optimization methods, etc.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface CandidateGenerator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Is this candidate generator able to generate more candidates? This will always return true in some
|
|
||||||
* cases, but some search strategies have a limit (grid search, for example)
|
|
||||||
*/
|
|
||||||
boolean hasMoreCandidates();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a candidate hyperparameter configuration
|
|
||||||
*/
|
|
||||||
Candidate getCandidate();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Report results for the candidate generator.
|
|
||||||
*
|
|
||||||
* @param result The results to report
|
|
||||||
*/
|
|
||||||
void reportResults(OptimizationResult result);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return Get the parameter space for this candidate generator
|
|
||||||
*/
|
|
||||||
ParameterSpace<?> getParameterSpace();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param rngSeed Set the random number generator seed for the candidate generator
|
|
||||||
*/
|
|
||||||
void setRngSeed(long rngSeed);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The type (class) of the generated candidates
|
|
||||||
*/
|
|
||||||
Class<?> getCandidateType();
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An optimization result represents the results of an optimization run, including the candidate configuration, the
|
|
||||||
* trained model, the score for that model, and index of the model
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
@JsonIgnoreProperties({"resultReference"})
|
|
||||||
public class OptimizationResult implements Serializable {
|
|
||||||
@JsonProperty
|
|
||||||
private Candidate candidate;
|
|
||||||
@JsonProperty
|
|
||||||
private Double score;
|
|
||||||
@JsonProperty
|
|
||||||
private int index;
|
|
||||||
@JsonProperty
|
|
||||||
private Object modelSpecificResults;
|
|
||||||
@JsonProperty
|
|
||||||
private CandidateInfo candidateInfo;
|
|
||||||
private ResultReference resultReference;
|
|
||||||
|
|
||||||
|
|
||||||
public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults,
|
|
||||||
CandidateInfo candidateInfo, ResultReference resultReference) {
|
|
||||||
this.candidate = candidate;
|
|
||||||
this.score = score;
|
|
||||||
this.index = index;
|
|
||||||
this.modelSpecificResults = modelSpecificResults;
|
|
||||||
this.candidateInfo = candidateInfo;
|
|
||||||
this.resultReference = resultReference;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,81 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnore;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ParameterSpace: defines the acceptable ranges of values a given parameter may take.
|
|
||||||
* Note that parameter spaces can be simple (like {@code ParameterSpace<Double>}) or complicated, including
|
|
||||||
* multiple nested ParameterSpaces
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface ParameterSpace<P> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some
|
|
||||||
* mapping function (such as the prior probability distribution)
|
|
||||||
*
|
|
||||||
* @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()}
|
|
||||||
*/
|
|
||||||
P getValue(double[] parameterValues);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from
|
|
||||||
* different parameter subpaces. (Thus, not every parameter may be used in every candidate)
|
|
||||||
*
|
|
||||||
* @return Number of hyperparameters to be optimized
|
|
||||||
*/
|
|
||||||
int numParameters();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any
|
|
||||||
* nested parameter spaces
|
|
||||||
*/
|
|
||||||
List<ParameterSpace> collectLeaves();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further
|
|
||||||
* nested parameter spaces. The map should be empty for leaf parameter spaces
|
|
||||||
*
|
|
||||||
* @return A map of nested parameter spaces
|
|
||||||
*/
|
|
||||||
Map<String, ParameterSpace> getNestedSpaces();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?)
|
|
||||||
*/
|
|
||||||
@JsonIgnore
|
|
||||||
boolean isLeaf();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* For leaf ParameterSpaces: set the indices of the leaf ParameterSpace.
|
|
||||||
* Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false.
|
|
||||||
*
|
|
||||||
* @param indices Indices to set. Length should equal {@link #numParameters()}
|
|
||||||
*/
|
|
||||||
void setIndices(int... indices);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,62 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Properties;
|
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The TaskCreator is used to take a candidate configuration, data provider and score function, and create something
|
|
||||||
* that can be executed as a Callable
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public interface TaskCreator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
|
||||||
*
|
|
||||||
* @param candidate Candidate (model) configuration to be trained
|
|
||||||
* @param dataProvider DataProvider, for the data
|
|
||||||
* @param scoreFunction Score function to be used to evaluate the model
|
|
||||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
|
||||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
Callable<OptimizationResult> create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction,
|
|
||||||
List<StatusListener> statusListeners, IOptimizationRunner runner);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a callable that can be executed to conduct the training of this model (given the model configuration)
|
|
||||||
*
|
|
||||||
* @param candidate Candidate (model) configuration to be trained
|
|
||||||
* @param dataSource Data source
|
|
||||||
* @param dataSourceProperties Properties (may be null) for the data source
|
|
||||||
* @param scoreFunction Score function to be used to evaluate the model
|
|
||||||
* @param statusListeners Status listeners, that can be used for callbacks (to UI, for example)
|
|
||||||
* @return A callable that returns an OptimizationResult, once optimization is complete
|
|
||||||
*/
|
|
||||||
Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties,
|
|
||||||
ScoreFunction scoreFunction, List<StatusListener> statusListeners, IOptimizationRunner runner);
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public class TaskCreatorProvider {
|
|
||||||
|
|
||||||
private static Map<Class<? extends ParameterSpace>, Class<? extends TaskCreator>> map = new HashMap<>();
|
|
||||||
|
|
||||||
public synchronized static TaskCreator defaultTaskCreatorFor(Class<? extends ParameterSpace> paramSpaceClass){
|
|
||||||
Class<? extends TaskCreator> c = map.get(paramSpaceClass);
|
|
||||||
try {
|
|
||||||
if(c == null){
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return c.newInstance();
|
|
||||||
} catch (Exception e){
|
|
||||||
throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public synchronized static void registerDefaultTaskCreatorClass(Class<? extends ParameterSpace> spaceClass,
|
|
||||||
Class<? extends TaskCreator> creatorClass){
|
|
||||||
map.put(spaceClass, creatorClass);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,82 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.adapter;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods
|
|
||||||
*
|
|
||||||
* @param <F> Type to convert from
|
|
||||||
* @param <T> Type to convert to
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@AllArgsConstructor
|
|
||||||
public abstract class ParameterSpaceAdapter<F, T> implements ParameterSpace<T> {
|
|
||||||
|
|
||||||
|
|
||||||
protected abstract T convertValue(F from);
|
|
||||||
|
|
||||||
protected abstract ParameterSpace<F> underlying();
|
|
||||||
|
|
||||||
protected abstract String underlyingName();
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T getValue(double[] parameterValues) {
|
|
||||||
return convertValue(underlying().getValue(parameterValues));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return underlying().numParameters();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
ParameterSpace p = underlying();
|
|
||||||
if(p.isLeaf()){
|
|
||||||
return Collections.singletonList(p);
|
|
||||||
}
|
|
||||||
return underlying().collectLeaves();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return false; //Underlying may be a leaf, however
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
underlying().setIndices(indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return underlying().toString();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,54 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* DataProvider interface abstracts out the providing of data
|
|
||||||
* @deprecated Use {@link DataSource}
|
|
||||||
*/
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
@Deprecated
|
|
||||||
public interface DataProvider extends Serializable {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get training data given some parameters for the data.
|
|
||||||
* Data parameters map is used to specify things like batch
|
|
||||||
* size data preprocessing
|
|
||||||
*
|
|
||||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
|
||||||
* @return training data
|
|
||||||
*/
|
|
||||||
Object trainData(Map<String, Object> dataParameters);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get training data given some parameters for the data. Data parameters map is used to specify things like batch
|
|
||||||
* size data preprocessing
|
|
||||||
*
|
|
||||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
|
||||||
* @return training data
|
|
||||||
*/
|
|
||||||
Object testData(Map<String, Object> dataParameters);
|
|
||||||
|
|
||||||
Class<?> getDataType();
|
|
||||||
}
|
|
|
@ -1,89 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This is a {@link DataProvider} for
|
|
||||||
* an {@link DataSetIteratorFactory} which
|
|
||||||
* based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
|
|
||||||
* will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
|
|
||||||
* for use with arbiter.
|
|
||||||
*
|
|
||||||
* This {@link DataProvider} is mainly meant for use for command line driven
|
|
||||||
* applications.
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class DataSetIteratorFactoryProvider implements DataProvider {
|
|
||||||
|
|
||||||
public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get training data given some parameters for the data.
|
|
||||||
* Data parameters map is used to specify things like batch
|
|
||||||
* size data preprocessing
|
|
||||||
*
|
|
||||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
|
||||||
* @return training data
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
|
|
||||||
return create(dataParameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get training data given some parameters for the data. Data parameters map
|
|
||||||
* is used to specify things like batch
|
|
||||||
* size data preprocessing
|
|
||||||
*
|
|
||||||
* @param dataParameters Parameters for data. May be null or empty for default data
|
|
||||||
* @return training data
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
|
|
||||||
return create(dataParameters);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Class<?> getDataType() {
|
|
||||||
return DataSetIteratorFactory.class;
|
|
||||||
}
|
|
||||||
|
|
||||||
private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
|
|
||||||
if (dataParameters == null)
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Data parameters is null. Please specify a class name to create a dataset iterator.");
|
|
||||||
if (!dataParameters.containsKey(FACTORY_KEY))
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"No data set iterator factory class found. Please specify a class name with key "
|
|
||||||
+ FACTORY_KEY);
|
|
||||||
String value = dataParameters.get(FACTORY_KEY).toString();
|
|
||||||
try {
|
|
||||||
Class<? extends DataSetIteratorFactory> clazz =
|
|
||||||
(Class<? extends DataSetIteratorFactory>) Class.forName(value);
|
|
||||||
return clazz.newInstance();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,57 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.data;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* DataSource: defines where the data should come from for training and testing.
|
|
||||||
* Note that implementations must have a no-argument contsructor
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public interface DataSource extends Serializable {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configure the current data source with the specified properties
|
|
||||||
* Note: These properties are fixed for the training instance, and are optionally provided by the user
|
|
||||||
* at the configuration stage.
|
|
||||||
* The properties could be anything - and are usually specific to each DataSource implementation.
|
|
||||||
* For example, values such as batch size could be set using these properties
|
|
||||||
* @param properties Properties to apply to the data source instance
|
|
||||||
*/
|
|
||||||
void configure(Properties properties);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
|
||||||
*/
|
|
||||||
Object trainData();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator
|
|
||||||
*/
|
|
||||||
Object testData();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The type of data returned by {@link #trainData()} and {@link #testData()}.
|
|
||||||
* Usually DataSetIterator or MultiDataSetIterator
|
|
||||||
* @return Class of the objects returned by trainData and testData
|
|
||||||
*/
|
|
||||||
Class<?> getDataType();
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,40 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.evaluation;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ModelEvaluator: Used to conduct additional evaluation.
|
|
||||||
* For example, this may be classification performance on a test set or similar
|
|
||||||
*/
|
|
||||||
public interface ModelEvaluator extends Serializable {
|
|
||||||
Object evaluateModel(Object model, DataProvider dataProvider);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The model types supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedModelTypes();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The datatypes supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedDataTypes();
|
|
||||||
}
|
|
|
@ -1,63 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple class to store optimization results in-memory.
|
|
||||||
* Not recommended for large (or a large number of) models.
|
|
||||||
*/
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class InMemoryResultSaver implements ResultSaver {
|
|
||||||
@Override
|
|
||||||
public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException {
|
|
||||||
return new InMemoryResult(result, modelResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Class<?>> getSupportedCandidateTypes() {
|
|
||||||
return Collections.<Class<?>>singletonList(Object.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Class<?>> getSupportedModelTypes() {
|
|
||||||
return Collections.<Class<?>>singletonList(Object.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
private static class InMemoryResult implements ResultReference {
|
|
||||||
private OptimizationResult result;
|
|
||||||
private Object modelResult;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public OptimizationResult getResult() throws IOException {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object getResultModel() throws IOException {
|
|
||||||
return modelResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,37 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Idea: We can't store all results in memory in general (might have thousands of candidates with millions of
|
|
||||||
* parameters each)
|
|
||||||
* So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database,
|
|
||||||
* and we can easily load it back into memory (if/when required) using the getResult() method
|
|
||||||
*/
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface ResultReference {
|
|
||||||
|
|
||||||
OptimizationResult getResult() throws IOException;
|
|
||||||
|
|
||||||
Object getResultModel() throws IOException;
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,57 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.saving;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later,
|
|
||||||
* regardless of where/how they are saved.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface ResultSaver {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Save the model (including configuration and any additional evaluation/results)
|
|
||||||
*
|
|
||||||
* @param result Optimization result for the model to save
|
|
||||||
* @param modelResult Model result to save
|
|
||||||
* @return ResultReference, such that the result can be loaded back into memory
|
|
||||||
* @throws IOException If IO error occurs during model saving
|
|
||||||
*/
|
|
||||||
ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The candidate types supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedCandidateTypes();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The model types supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedModelTypes();
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,75 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.score;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ScoreFunction defines the objective of hyperparameter optimization.
|
|
||||||
* Specifically, it is used to calculate a score for a given model, relative to the data set provided
|
|
||||||
* in the configuration.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface ScoreFunction extends Serializable {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate and return the score, for the given model and data provider
|
|
||||||
*
|
|
||||||
* @param model Model to score
|
|
||||||
* @param dataProvider Data provider - data to use
|
|
||||||
* @param dataParameters Parameters for data
|
|
||||||
* @return Calculated score
|
|
||||||
*/
|
|
||||||
double score(Object model, DataProvider dataProvider, Map<String, Object> dataParameters);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate and return the score, for the given model and data provider
|
|
||||||
*
|
|
||||||
* @param model Model to score
|
|
||||||
* @param dataSource Data source
|
|
||||||
* @param dataSourceProperties data source properties
|
|
||||||
* @return Calculated score
|
|
||||||
*/
|
|
||||||
double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Should this score function be minimized or maximized?
|
|
||||||
*
|
|
||||||
* @return true if score should be minimized, false if score should be maximized
|
|
||||||
*/
|
|
||||||
boolean minimize();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The model types supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedModelTypes();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The data types supported by this class
|
|
||||||
*/
|
|
||||||
List<Class<?>> getSupportedDataTypes();
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Terminate hyperparameter search when the number of candidates exceeds a specified value.
|
|
||||||
* Note that this is counted as number of completed candidates, plus number of failed candidates.
|
|
||||||
*/
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Data
|
|
||||||
public class MaxCandidatesCondition implements TerminationCondition {
|
|
||||||
@JsonProperty
|
|
||||||
private int maxCandidates;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
|
||||||
return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "MaxCandidatesCondition(" + maxCandidates + ")";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,81 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.joda.time.format.DateTimeFormat;
|
|
||||||
import org.joda.time.format.DateTimeFormatter;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Terminate hyperparameter optimization after
|
|
||||||
* a fixed amount of time has passed
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Data
|
|
||||||
public class MaxTimeCondition implements TerminationCondition {
|
|
||||||
private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ");
|
|
||||||
|
|
||||||
private long duration;
|
|
||||||
private TimeUnit timeUnit;
|
|
||||||
private long startTime;
|
|
||||||
private long endTime;
|
|
||||||
|
|
||||||
|
|
||||||
private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit,
|
|
||||||
@JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) {
|
|
||||||
this.duration = duration;
|
|
||||||
this.timeUnit = timeUnit;
|
|
||||||
this.startTime = startTime;
|
|
||||||
this.endTime = endTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param duration Duration of time
|
|
||||||
* @param timeUnit Unit that the duration is specified in
|
|
||||||
*/
|
|
||||||
public MaxTimeCondition(long duration, TimeUnit timeUnit) {
|
|
||||||
this.duration = duration;
|
|
||||||
this.timeUnit = timeUnit;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initialize(IOptimizationRunner optimizationRunner) {
|
|
||||||
startTime = System.currentTimeMillis();
|
|
||||||
this.endTime = startTime + timeUnit.toMillis(duration);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
|
||||||
return System.currentTimeMillis() >= endTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
if (startTime > 0) {
|
|
||||||
return "MaxTimeCondition(" + duration + "," + timeUnit + ",start=\"" + formatter.print(startTime)
|
|
||||||
+ "\",end=\"" + formatter.print(endTime) + "\")";
|
|
||||||
} else {
|
|
||||||
return "MaxTimeCondition(" + duration + "," + timeUnit + "\")";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.api.termination;
|
|
||||||
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Global termination condition for conducting hyperparameter optimization.
|
|
||||||
* Termination conditions are used to determine if/when the optimization should stop.
|
|
||||||
*/
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
|
||||||
public interface TerminationCondition {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize the termination condition (such as starting timers, etc).
|
|
||||||
*/
|
|
||||||
void initialize(IOptimizationRunner optimizationRunner);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Determine whether optimization should be terminated
|
|
||||||
*
|
|
||||||
* @param optimizationRunner Optimization runner
|
|
||||||
* @return true if learning should be terminated, false otherwise
|
|
||||||
*/
|
|
||||||
boolean terminate(IOptimizationRunner optimizationRunner);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,221 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.config;
|
|
||||||
|
|
||||||
import lombok.*;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* OptimizationConfiguration ties together all of the various
|
|
||||||
* components (such as data, score functions, result saving etc)
|
|
||||||
* required to execute hyperparameter optimization.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"})
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public class OptimizationConfiguration {
|
|
||||||
@JsonSerialize
|
|
||||||
private DataProvider dataProvider;
|
|
||||||
@JsonSerialize
|
|
||||||
private Class<? extends DataSource> dataSource;
|
|
||||||
@JsonSerialize
|
|
||||||
private Properties dataSourceProperties;
|
|
||||||
@JsonSerialize
|
|
||||||
private CandidateGenerator candidateGenerator;
|
|
||||||
@JsonSerialize
|
|
||||||
private ResultSaver resultSaver;
|
|
||||||
@JsonSerialize
|
|
||||||
private ScoreFunction scoreFunction;
|
|
||||||
@JsonSerialize
|
|
||||||
private List<TerminationCondition> terminationConditions;
|
|
||||||
@JsonSerialize
|
|
||||||
private Long rngSeed;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
private long executionStartTime;
|
|
||||||
|
|
||||||
|
|
||||||
private OptimizationConfiguration(Builder builder) {
|
|
||||||
this.dataProvider = builder.dataProvider;
|
|
||||||
this.dataSource = builder.dataSource;
|
|
||||||
this.dataSourceProperties = builder.dataSourceProperties;
|
|
||||||
this.candidateGenerator = builder.candidateGenerator;
|
|
||||||
this.resultSaver = builder.resultSaver;
|
|
||||||
this.scoreFunction = builder.scoreFunction;
|
|
||||||
this.terminationConditions = builder.terminationConditions;
|
|
||||||
this.rngSeed = builder.rngSeed;
|
|
||||||
|
|
||||||
if (rngSeed != null)
|
|
||||||
candidateGenerator.setRngSeed(rngSeed);
|
|
||||||
|
|
||||||
//Validate the configuration: data types, score types, etc
|
|
||||||
//TODO
|
|
||||||
|
|
||||||
//Validate that the dataSource has a no-arg constructor
|
|
||||||
if(dataSource != null){
|
|
||||||
try{
|
|
||||||
dataSource.getConstructor();
|
|
||||||
} catch (NoSuchMethodException e){
|
|
||||||
throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
|
|
||||||
private DataProvider dataProvider;
|
|
||||||
private Class<? extends DataSource> dataSource;
|
|
||||||
private Properties dataSourceProperties;
|
|
||||||
private CandidateGenerator candidateGenerator;
|
|
||||||
private ResultSaver resultSaver;
|
|
||||||
private ScoreFunction scoreFunction;
|
|
||||||
private List<TerminationCondition> terminationConditions;
|
|
||||||
private Long rngSeed;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use {@link #dataSource(Class, Properties)}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public Builder dataProvider(DataProvider dataProvider) {
|
|
||||||
this.dataProvider = dataProvider;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* DataSource: defines where the data should come from for training and testing.
|
|
||||||
* Note that implementations must have a no-argument constructor
|
|
||||||
* @param dataSource Class for the data source
|
|
||||||
* @param dataSourceProperties May be null. Properties for configuring the data source
|
|
||||||
*/
|
|
||||||
public Builder dataSource(Class<? extends DataSource> dataSource, Properties dataSourceProperties){
|
|
||||||
this.dataSource = dataSource;
|
|
||||||
this.dataSourceProperties = dataSourceProperties;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder candidateGenerator(CandidateGenerator candidateGenerator) {
|
|
||||||
this.candidateGenerator = candidateGenerator;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder modelSaver(ResultSaver resultSaver) {
|
|
||||||
this.resultSaver = resultSaver;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder scoreFunction(ScoreFunction scoreFunction) {
|
|
||||||
this.scoreFunction = scoreFunction;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Termination conditions to use
|
|
||||||
* @param conditions
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder terminationConditions(TerminationCondition... conditions) {
|
|
||||||
terminationConditions = Arrays.asList(conditions);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder terminationConditions(List<TerminationCondition> terminationConditions) {
|
|
||||||
this.terminationConditions = terminationConditions;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder rngSeed(long rngSeed) {
|
|
||||||
this.rngSeed = rngSeed;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public OptimizationConfiguration build() {
|
|
||||||
return new OptimizationConfiguration(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create an optimization configuration from the json
|
|
||||||
* @param json the json to create the config from
|
|
||||||
* For type definitions
|
|
||||||
* @see OptimizationConfiguration
|
|
||||||
*/
|
|
||||||
public static OptimizationConfiguration fromYaml(String json) {
|
|
||||||
try {
|
|
||||||
return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class);
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create an optimization configuration from the json
|
|
||||||
* @param json the json to create the config from
|
|
||||||
* @see OptimizationConfiguration
|
|
||||||
*/
|
|
||||||
public static OptimizationConfiguration fromJson(String json) {
|
|
||||||
try {
|
|
||||||
return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class);
|
|
||||||
} catch (IOException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return a json configuration of this optimization configuration
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public String toJson() {
|
|
||||||
try {
|
|
||||||
return JsonMapper.getMapper().writeValueAsString(this);
|
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return a yaml configuration of this optimization configuration
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public String toYaml() {
|
|
||||||
try {
|
|
||||||
return JsonMapper.getYamlMapper().writeValueAsString(this);
|
|
||||||
} catch (JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,97 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
|
||||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
|
||||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Degenerate distribution: i.e., integer "distribution" that is just a fixed value
|
|
||||||
*/
|
|
||||||
public class DegenerateIntegerDistribution implements IntegerDistribution {
|
|
||||||
private int value;
|
|
||||||
|
|
||||||
public DegenerateIntegerDistribution(int value) {
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double probability(int x) {
|
|
||||||
return (x == value ? 1.0 : 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double cumulativeProbability(int x) {
|
|
||||||
return (x >= value ? 1.0 : 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException {
|
|
||||||
return (value >= x0 && value <= x1 ? 1.0 : 0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int inverseCumulativeProbability(double p) throws OutOfRangeException {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getNumericalMean() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getNumericalVariance() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getSupportLowerBound() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getSupportUpperBound() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSupportConnected() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reseedRandomGenerator(long seed) {
|
|
||||||
//no op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int sample() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int[] sample(int sampleSize) {
|
|
||||||
int[] out = new int[sampleSize];
|
|
||||||
Arrays.fill(out, value);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,149 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods,
|
|
||||||
* don't implement serializable etc.
|
|
||||||
* Which makes unit testing etc quite difficult.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class DistributionUtils {
|
|
||||||
|
|
||||||
private DistributionUtils() {}
|
|
||||||
|
|
||||||
|
|
||||||
public static boolean distributionsEqual(RealDistribution a, RealDistribution b) {
|
|
||||||
if (a.getClass() != b.getClass())
|
|
||||||
return false;
|
|
||||||
Class<?> c = a.getClass();
|
|
||||||
if (c == BetaDistribution.class) {
|
|
||||||
BetaDistribution ba = (BetaDistribution) a;
|
|
||||||
BetaDistribution bb = (BetaDistribution) b;
|
|
||||||
|
|
||||||
return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta();
|
|
||||||
} else if (c == CauchyDistribution.class) {
|
|
||||||
CauchyDistribution ca = (CauchyDistribution) a;
|
|
||||||
CauchyDistribution cb = (CauchyDistribution) b;
|
|
||||||
return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale();
|
|
||||||
} else if (c == ChiSquaredDistribution.class) {
|
|
||||||
ChiSquaredDistribution ca = (ChiSquaredDistribution) a;
|
|
||||||
ChiSquaredDistribution cb = (ChiSquaredDistribution) b;
|
|
||||||
return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom();
|
|
||||||
} else if (c == ExponentialDistribution.class) {
|
|
||||||
ExponentialDistribution ea = (ExponentialDistribution) a;
|
|
||||||
ExponentialDistribution eb = (ExponentialDistribution) b;
|
|
||||||
return ea.getMean() == eb.getMean();
|
|
||||||
} else if (c == FDistribution.class) {
|
|
||||||
FDistribution fa = (FDistribution) a;
|
|
||||||
FDistribution fb = (FDistribution) b;
|
|
||||||
return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom()
|
|
||||||
&& fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom();
|
|
||||||
} else if (c == GammaDistribution.class) {
|
|
||||||
GammaDistribution ga = (GammaDistribution) a;
|
|
||||||
GammaDistribution gb = (GammaDistribution) b;
|
|
||||||
return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale();
|
|
||||||
} else if (c == LevyDistribution.class) {
|
|
||||||
LevyDistribution la = (LevyDistribution) a;
|
|
||||||
LevyDistribution lb = (LevyDistribution) b;
|
|
||||||
return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale();
|
|
||||||
} else if (c == LogNormalDistribution.class) {
|
|
||||||
LogNormalDistribution la = (LogNormalDistribution) a;
|
|
||||||
LogNormalDistribution lb = (LogNormalDistribution) b;
|
|
||||||
return la.getScale() == lb.getScale() && la.getShape() == lb.getShape();
|
|
||||||
} else if (c == NormalDistribution.class) {
|
|
||||||
NormalDistribution na = (NormalDistribution) a;
|
|
||||||
NormalDistribution nb = (NormalDistribution) b;
|
|
||||||
return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation();
|
|
||||||
} else if (c == ParetoDistribution.class) {
|
|
||||||
ParetoDistribution pa = (ParetoDistribution) a;
|
|
||||||
ParetoDistribution pb = (ParetoDistribution) b;
|
|
||||||
return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape();
|
|
||||||
} else if (c == TDistribution.class) {
|
|
||||||
TDistribution ta = (TDistribution) a;
|
|
||||||
TDistribution tb = (TDistribution) b;
|
|
||||||
return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom();
|
|
||||||
} else if (c == TriangularDistribution.class) {
|
|
||||||
TriangularDistribution ta = (TriangularDistribution) a;
|
|
||||||
TriangularDistribution tb = (TriangularDistribution) b;
|
|
||||||
return ta.getSupportLowerBound() == tb.getSupportLowerBound()
|
|
||||||
&& ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode();
|
|
||||||
} else if (c == UniformRealDistribution.class) {
|
|
||||||
UniformRealDistribution ua = (UniformRealDistribution) a;
|
|
||||||
UniformRealDistribution ub = (UniformRealDistribution) b;
|
|
||||||
return ua.getSupportLowerBound() == ub.getSupportLowerBound()
|
|
||||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
|
||||||
} else if (c == WeibullDistribution.class) {
|
|
||||||
WeibullDistribution wa = (WeibullDistribution) a;
|
|
||||||
WeibullDistribution wb = (WeibullDistribution) b;
|
|
||||||
return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale();
|
|
||||||
} else if (c == LogUniformDistribution.class ){
|
|
||||||
LogUniformDistribution lu_a = (LogUniformDistribution)a;
|
|
||||||
LogUniformDistribution lu_b = (LogUniformDistribution)b;
|
|
||||||
return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax();
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) {
|
|
||||||
if (a.getClass() != b.getClass())
|
|
||||||
return false;
|
|
||||||
Class<?> c = a.getClass();
|
|
||||||
|
|
||||||
if (c == BinomialDistribution.class) {
|
|
||||||
BinomialDistribution ba = (BinomialDistribution) a;
|
|
||||||
BinomialDistribution bb = (BinomialDistribution) b;
|
|
||||||
return ba.getNumberOfTrials() == bb.getNumberOfTrials()
|
|
||||||
&& ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess();
|
|
||||||
} else if (c == GeometricDistribution.class) {
|
|
||||||
GeometricDistribution ga = (GeometricDistribution) a;
|
|
||||||
GeometricDistribution gb = (GeometricDistribution) b;
|
|
||||||
return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess();
|
|
||||||
} else if (c == HypergeometricDistribution.class) {
|
|
||||||
HypergeometricDistribution ha = (HypergeometricDistribution) a;
|
|
||||||
HypergeometricDistribution hb = (HypergeometricDistribution) b;
|
|
||||||
return ha.getPopulationSize() == hb.getPopulationSize()
|
|
||||||
&& ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses()
|
|
||||||
&& ha.getSampleSize() == hb.getSampleSize();
|
|
||||||
} else if (c == PascalDistribution.class) {
|
|
||||||
PascalDistribution pa = (PascalDistribution) a;
|
|
||||||
PascalDistribution pb = (PascalDistribution) b;
|
|
||||||
return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses()
|
|
||||||
&& pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess();
|
|
||||||
} else if (c == PoissonDistribution.class) {
|
|
||||||
PoissonDistribution pa = (PoissonDistribution) a;
|
|
||||||
PoissonDistribution pb = (PoissonDistribution) b;
|
|
||||||
return pa.getMean() == pb.getMean();
|
|
||||||
} else if (c == UniformIntegerDistribution.class) {
|
|
||||||
UniformIntegerDistribution ua = (UniformIntegerDistribution) a;
|
|
||||||
UniformIntegerDistribution ub = (UniformIntegerDistribution) b;
|
|
||||||
return ua.getSupportUpperBound() == ub.getSupportUpperBound()
|
|
||||||
&& ua.getSupportUpperBound() == ub.getSupportUpperBound();
|
|
||||||
} else if (c == ZipfDistribution.class) {
|
|
||||||
ZipfDistribution za = (ZipfDistribution) a;
|
|
||||||
ZipfDistribution zb = (ZipfDistribution) b;
|
|
||||||
return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements();
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,155 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.distribution;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.base.Preconditions;
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.apache.commons.math3.distribution.RealDistribution;
|
|
||||||
import org.apache.commons.math3.exception.NumberIsTooLargeException;
|
|
||||||
import org.apache.commons.math3.exception.OutOfRangeException;
|
|
||||||
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Log uniform distribution, with support in range [min, max] for min > 0
|
|
||||||
*
|
|
||||||
* Reference: <a href="https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php">https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php</a>
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class LogUniformDistribution implements RealDistribution {
|
|
||||||
|
|
||||||
@Getter private final double min;
|
|
||||||
@Getter private final double max;
|
|
||||||
|
|
||||||
private final double logMin;
|
|
||||||
private final double logMax;
|
|
||||||
|
|
||||||
private transient Random rng = new Random();
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param min Minimum value
|
|
||||||
* @param max Maximum value
|
|
||||||
*/
|
|
||||||
public LogUniformDistribution(double min, double max) {
|
|
||||||
Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min);
|
|
||||||
Preconditions.checkArgument(max > min, "Maximum must be > min. Got: (min, max)=("
|
|
||||||
+ min + "," + max + ")");
|
|
||||||
this.min = min;
|
|
||||||
this.max = max;
|
|
||||||
|
|
||||||
this.logMin = Math.log(min);
|
|
||||||
this.logMax = Math.log(max);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double probability(double x) {
|
|
||||||
if(x < min || x > max){
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 1.0 / (x * (logMax - logMin));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double density(double x) {
|
|
||||||
return probability(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double cumulativeProbability(double x) {
|
|
||||||
if(x <= min){
|
|
||||||
return 0.0;
|
|
||||||
} else if(x >= max){
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (Math.log(x)-logMin)/(logMax-logMin);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException {
|
|
||||||
return cumulativeProbability(x1) - cumulativeProbability(x0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double inverseCumulativeProbability(double p) throws OutOfRangeException {
|
|
||||||
Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p);
|
|
||||||
return Math.exp(p * (logMax-logMin) + logMin);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getNumericalMean() {
|
|
||||||
return (max-min)/(logMax-logMin);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getNumericalVariance() {
|
|
||||||
double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min);
|
|
||||||
return d1 / (2*Math.pow(logMax-logMin, 2.0));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getSupportLowerBound() {
|
|
||||||
return min;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getSupportUpperBound() {
|
|
||||||
return max;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSupportLowerBoundInclusive() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSupportUpperBoundInclusive() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isSupportConnected() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reseedRandomGenerator(long seed) {
|
|
||||||
rng.setSeed(seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double sample() {
|
|
||||||
return inverseCumulativeProbability(rng.nextDouble());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] sample(int sampleSize) {
|
|
||||||
double[] d = new double[sampleSize];
|
|
||||||
for( int i=0; i<sampleSize; i++ ){
|
|
||||||
d[i] = sample();
|
|
||||||
}
|
|
||||||
return d;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString(){
|
|
||||||
return "LogUniformDistribution(min=" + min + ",max=" + max + ")";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,91 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BaseCandidateGenerator: abstract class upon which {@link RandomSearchGenerator},
|
|
||||||
* {@link GridSearchCandidateGenerator} and {@link GeneticSearchCandidateGenerator}
|
|
||||||
* are built.
|
|
||||||
*
|
|
||||||
* @param <T> Type of candidates to generate
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@EqualsAndHashCode(exclude = {"rng", "candidateCounter"})
|
|
||||||
public abstract class BaseCandidateGenerator<T> implements CandidateGenerator {
|
|
||||||
protected ParameterSpace<T> parameterSpace;
|
|
||||||
protected AtomicInteger candidateCounter = new AtomicInteger(0);
|
|
||||||
protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
protected Map<String, Object> dataParameters;
|
|
||||||
protected boolean initDone = false;
|
|
||||||
|
|
||||||
public BaseCandidateGenerator(ParameterSpace<T> parameterSpace, Map<String, Object> dataParameters,
|
|
||||||
boolean initDone) {
|
|
||||||
this.parameterSpace = parameterSpace;
|
|
||||||
this.dataParameters = dataParameters;
|
|
||||||
this.initDone = initDone;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void initialize() {
|
|
||||||
if(!initDone) {
|
|
||||||
//First: collect leaf parameter spaces objects and remove duplicates
|
|
||||||
List<ParameterSpace> noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
|
||||||
|
|
||||||
//Second: assign each a number
|
|
||||||
int i = 0;
|
|
||||||
for (ParameterSpace ps : noDuplicatesList) {
|
|
||||||
int np = ps.numParameters();
|
|
||||||
if (np == 1) {
|
|
||||||
ps.setIndices(i++);
|
|
||||||
} else {
|
|
||||||
int[] values = new int[np];
|
|
||||||
for (int j = 0; j < np; j++)
|
|
||||||
values[j] = i++;
|
|
||||||
ps.setIndices(values);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
initDone = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ParameterSpace<T> getParameterSpace() {
|
|
||||||
return parameterSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reportResults(OptimizationResult result) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setRngSeed(long rngSeed) {
|
|
||||||
rng.setSeed(rngSeed);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,187 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Uses a genetic algorithm to generate candidates.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator {
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
protected final PopulationModel populationModel;
|
|
||||||
|
|
||||||
protected final ChromosomeFactory chromosomeFactory;
|
|
||||||
protected final SelectionOperator selectionOperator;
|
|
||||||
|
|
||||||
protected boolean hasMoreCandidates = true;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
protected final ParameterSpace<?> parameterSpace;
|
|
||||||
|
|
||||||
protected Map<String, Object> dataParameters;
|
|
||||||
protected boolean initDone;
|
|
||||||
protected boolean minimizeScore;
|
|
||||||
protected PopulationModel populationModel;
|
|
||||||
protected ChromosomeFactory chromosomeFactory;
|
|
||||||
protected SelectionOperator selectionOperator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
|
||||||
* @param scoreFunction The score function that will be used in the OptimizationConfiguration
|
|
||||||
*/
|
|
||||||
public Builder(ParameterSpace<?> parameterSpace, ScoreFunction scoreFunction) {
|
|
||||||
this.parameterSpace = parameterSpace;
|
|
||||||
this.minimizeScore = scoreFunction.minimize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param populationModel The PopulationModel instance to use.
|
|
||||||
*/
|
|
||||||
public Builder populationModel(PopulationModel populationModel) {
|
|
||||||
this.populationModel = populationModel;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator
|
|
||||||
*/
|
|
||||||
public Builder selectionOperator(SelectionOperator selectionOperator) {
|
|
||||||
this.selectionOperator = selectionOperator;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder dataParameters(Map<String, Object> dataParameters) {
|
|
||||||
|
|
||||||
this.dataParameters = dataParameters;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) {
|
|
||||||
this.initDone = initDone;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param chromosomeFactory The ChromosomeFactory to use
|
|
||||||
*/
|
|
||||||
public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) {
|
|
||||||
this.chromosomeFactory = chromosomeFactory;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GeneticSearchCandidateGenerator build() {
|
|
||||||
if (populationModel == null) {
|
|
||||||
PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer();
|
|
||||||
populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (chromosomeFactory == null) {
|
|
||||||
chromosomeFactory = new ChromosomeFactory();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selectionOperator == null) {
|
|
||||||
selectionOperator = new GeneticSelectionOperator.Builder().build();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new GeneticSearchCandidateGenerator(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private GeneticSearchCandidateGenerator(Builder builder) {
|
|
||||||
super(builder.parameterSpace, builder.dataParameters, builder.initDone);
|
|
||||||
|
|
||||||
initialize();
|
|
||||||
|
|
||||||
chromosomeFactory = builder.chromosomeFactory;
|
|
||||||
populationModel = builder.populationModel;
|
|
||||||
selectionOperator = builder.selectionOperator;
|
|
||||||
|
|
||||||
chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters());
|
|
||||||
populationModel.initializeInstance(builder.minimizeScore);
|
|
||||||
selectionOperator.initializeInstance(populationModel, chromosomeFactory);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasMoreCandidates() {
|
|
||||||
return hasMoreCandidates;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Candidate getCandidate() {
|
|
||||||
|
|
||||||
double[] values = null;
|
|
||||||
Object value = null;
|
|
||||||
Exception e = null;
|
|
||||||
|
|
||||||
try {
|
|
||||||
values = selectionOperator.buildNextGenes();
|
|
||||||
value = parameterSpace.getValue(values);
|
|
||||||
} catch (GeneticGenerationException e2) {
|
|
||||||
log.warn("Error generating candidate", e2);
|
|
||||||
e = e2;
|
|
||||||
hasMoreCandidates = false;
|
|
||||||
} catch (Exception e2) {
|
|
||||||
log.warn("Error getting configuration for candidate", e2);
|
|
||||||
e = e2;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Class<?> getCandidateType() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "GeneticSearchCandidateGenerator";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reportResults(OptimizationResult result) {
|
|
||||||
if (result.getScore() == null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(),
|
|
||||||
result.getScore());
|
|
||||||
populationModel.add(newChromosome);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,232 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator;
|
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.math3.random.RandomAdaptor;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.util.LeafUtils;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.<br>
|
|
||||||
* Note that:<br>
|
|
||||||
* - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for
|
|
||||||
* that hyperparameter<br>
|
|
||||||
* - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can
|
|
||||||
* be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument<br>
|
|
||||||
* - For continuous parameters: the grid size is equal to {@code discretizationCount}.<br>
|
|
||||||
* In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.<br>
|
|
||||||
* Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account
|
|
||||||
* when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@EqualsAndHashCode(exclude = {"order"}, callSuper = true)
|
|
||||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
|
||||||
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* In what order should candidates be generated?<br>
|
|
||||||
* <b>Sequential</b>: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last
|
|
||||||
* will be changed least rapidly.<br>
|
|
||||||
* <b>RandomOrder</b>: generate candidates in a random order<br>
|
|
||||||
* In both cases, the same candidates will be generated; only the order of generation is different
|
|
||||||
*/
|
|
||||||
public enum Mode {
|
|
||||||
Sequential, RandomOrder
|
|
||||||
}
|
|
||||||
|
|
||||||
private final int discretizationCount;
|
|
||||||
private final Mode mode;
|
|
||||||
|
|
||||||
private int[] numValuesPerParam;
|
|
||||||
@Getter
|
|
||||||
private int totalNumCandidates;
|
|
||||||
private Queue<Integer> order;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
|
||||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
|
||||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
|
||||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
|
||||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
|
||||||
* in which candidates should be generated.
|
|
||||||
*/
|
|
||||||
public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
|
||||||
@JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode,
|
|
||||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
|
||||||
@JsonProperty("initDone") boolean initDone) {
|
|
||||||
super(parameterSpace, dataParameters, initDone);
|
|
||||||
this.discretizationCount = discretizationCount;
|
|
||||||
this.mode = mode;
|
|
||||||
initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param parameterSpace ParameterSpace from which to generate candidates
|
|
||||||
* @param discretizationCount For continuous parameters: into how many values should we discretize them into?
|
|
||||||
* For example, suppose continuous parameter is in range [0,1] with 3 bins:
|
|
||||||
* do [0.0, 0.5, 1.0]. Note that if all values
|
|
||||||
* @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order
|
|
||||||
* in which candidates should be generated.
|
|
||||||
*/
|
|
||||||
public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
|
|
||||||
Map<String, Object> dataParameters){
|
|
||||||
this(parameterSpace, discretizationCount, mode, dataParameters, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void initialize() {
|
|
||||||
super.initialize();
|
|
||||||
|
|
||||||
List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves());
|
|
||||||
int nParams = leaves.size();
|
|
||||||
|
|
||||||
//Work out for each parameter: is it continuous or discrete?
|
|
||||||
// for grid search: discrete values are grid-searchable as-is
|
|
||||||
// continuous values: discretize using 'discretizationCount' bins
|
|
||||||
// integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary
|
|
||||||
numValuesPerParam = new int[nParams];
|
|
||||||
long searchSize = 1;
|
|
||||||
for (int i = 0; i < nParams; i++) {
|
|
||||||
ParameterSpace ps = leaves.get(i);
|
|
||||||
if (ps instanceof DiscreteParameterSpace) {
|
|
||||||
DiscreteParameterSpace dps = (DiscreteParameterSpace) ps;
|
|
||||||
numValuesPerParam[i] = dps.numValues();
|
|
||||||
} else if (ps instanceof IntegerParameterSpace) {
|
|
||||||
IntegerParameterSpace ips = (IntegerParameterSpace) ps;
|
|
||||||
int min = ips.getMin();
|
|
||||||
int max = ips.getMax();
|
|
||||||
//Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000)
|
|
||||||
numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount);
|
|
||||||
} else if (ps instanceof FixedValue){
|
|
||||||
numValuesPerParam[i] = 1;
|
|
||||||
} else {
|
|
||||||
numValuesPerParam[i] = discretizationCount;
|
|
||||||
}
|
|
||||||
searchSize *= numValuesPerParam[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (searchSize >= Integer.MAX_VALUE)
|
|
||||||
throw new IllegalStateException("Invalid search: cannot process search with " + searchSize
|
|
||||||
+ " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound?
|
|
||||||
|
|
||||||
order = new ConcurrentLinkedQueue<>();
|
|
||||||
|
|
||||||
totalNumCandidates = (int) searchSize;
|
|
||||||
switch (mode) {
|
|
||||||
case Sequential:
|
|
||||||
for (int i = 0; i < totalNumCandidates; i++) {
|
|
||||||
order.add(i);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case RandomOrder:
|
|
||||||
List<Integer> tempList = new ArrayList<>(totalNumCandidates);
|
|
||||||
for (int i = 0; i < totalNumCandidates; i++) {
|
|
||||||
tempList.add(i);
|
|
||||||
}
|
|
||||||
|
|
||||||
Collections.shuffle(tempList, new RandomAdaptor(rng));
|
|
||||||
order.addAll(tempList);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasMoreCandidates() {
|
|
||||||
return !order.isEmpty();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Candidate getCandidate() {
|
|
||||||
int next = order.remove();
|
|
||||||
|
|
||||||
//Next: max integer (candidate number) to values
|
|
||||||
double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates);
|
|
||||||
|
|
||||||
Object value = null;
|
|
||||||
Exception e = null;
|
|
||||||
try {
|
|
||||||
value = parameterSpace.getValue(values);
|
|
||||||
} catch (Exception e2) {
|
|
||||||
log.warn("Error getting configuration for candidate", e2);
|
|
||||||
e = e2;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Class<?> getCandidateType() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
|
|
||||||
//How? first map to index of num possible values. Then: to double values in range 0 to 1
|
|
||||||
// 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc
|
|
||||||
//Based on: Nd4j Shape.ind2sub
|
|
||||||
|
|
||||||
int countNon1 = 0;
|
|
||||||
for( int i : numValuesPerParam)
|
|
||||||
if(i > 1)
|
|
||||||
countNon1++;
|
|
||||||
|
|
||||||
int denom = product;
|
|
||||||
int num = candidateIdx;
|
|
||||||
int[] index = new int[numValuesPerParam.length];
|
|
||||||
|
|
||||||
for (int i = index.length - 1; i >= 0; i--) {
|
|
||||||
denom /= numValuesPerParam[i];
|
|
||||||
index[i] = num / denom;
|
|
||||||
num %= denom;
|
|
||||||
}
|
|
||||||
|
|
||||||
//Now: convert indexes to values in range [0,1]
|
|
||||||
//min value -> 0
|
|
||||||
//max value -> 1
|
|
||||||
double[] out = new double[countNon1];
|
|
||||||
int outIdx = 0;
|
|
||||||
for (int i = 0; i < numValuesPerParam.length; i++) {
|
|
||||||
if (numValuesPerParam[i] > 1){
|
|
||||||
out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "GridSearchCandidateGenerator(mode=" + mode + ")";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,93 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator;
|
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* RandomSearchGenerator: generates candidates at random.<br>
|
|
||||||
* Note: if a probability distribution is provided for continuous hyperparameters,
|
|
||||||
* this will be taken into account
|
|
||||||
* when generating candidates. This allows the search to be weighted more towards
|
|
||||||
* certain values according to a probability
|
|
||||||
* density. For example: generate samples for learning rate according to log uniform distribution
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
|
||||||
@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
|
|
||||||
public class RandomSearchGenerator extends BaseCandidateGenerator {
|
|
||||||
|
|
||||||
@JsonCreator
|
|
||||||
public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
|
|
||||||
@JsonProperty("dataParameters") Map<String, Object> dataParameters,
|
|
||||||
@JsonProperty("initDone") boolean initDone) {
|
|
||||||
super(parameterSpace, dataParameters, initDone);
|
|
||||||
initialize();
|
|
||||||
}
|
|
||||||
|
|
||||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace, Map<String,Object> dataParameters){
|
|
||||||
this(parameterSpace, dataParameters, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
public RandomSearchGenerator(ParameterSpace<?> parameterSpace){
|
|
||||||
this(parameterSpace, null, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasMoreCandidates() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Candidate getCandidate() {
|
|
||||||
double[] randomValues = new double[parameterSpace.numParameters()];
|
|
||||||
for (int i = 0; i < randomValues.length; i++)
|
|
||||||
randomValues[i] = rng.nextDouble();
|
|
||||||
|
|
||||||
Object value = null;
|
|
||||||
Exception e = null;
|
|
||||||
try {
|
|
||||||
value = parameterSpace.getValue(randomValues);
|
|
||||||
} catch (Exception e2) {
|
|
||||||
log.warn("Error getting configuration for candidate", e2);
|
|
||||||
e = e2;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Class<?> getCandidateType() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "RandomSearchGenerator";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Candidates are stored as Chromosome in the population model
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class Chromosome {
|
|
||||||
/**
|
|
||||||
* The fitness score of the genes.
|
|
||||||
*/
|
|
||||||
protected final double fitness;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The genes.
|
|
||||||
*/
|
|
||||||
protected final double[] genes;
|
|
||||||
|
|
||||||
public Chromosome(double[] genes, double fitness) {
|
|
||||||
this.genes = genes;
|
|
||||||
this.fitness = fitness;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class ChromosomeFactory {
|
|
||||||
private int chromosomeLength;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called by the GeneticSearchCandidateGenerator.
|
|
||||||
*/
|
|
||||||
public void initializeInstance(int chromosomeLength) {
|
|
||||||
this.chromosomeLength = chromosomeLength;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new instance of a Chromosome
|
|
||||||
*
|
|
||||||
* @param genes The genes
|
|
||||||
* @param fitness The fitness score
|
|
||||||
* @return A new instance of Chromosome
|
|
||||||
*/
|
|
||||||
public Chromosome createChromosome(double[] genes, double fitness) {
|
|
||||||
return new Chromosome(genes, fitness);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The number of genes in a chromosome
|
|
||||||
*/
|
|
||||||
public int getChromosomeLength() {
|
|
||||||
return chromosomeLength;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,120 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A crossover operator that linearly combines the genes of two parents. <br>
|
|
||||||
* When a crossover is generated (with a of probability <i>crossover rate</i>), each genes is a linear combination of the corresponding genes of the parents.
|
|
||||||
* <p>
|
|
||||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i>
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class ArithmeticCrossover extends TwoParentsCrossoverOperator {
|
|
||||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
|
||||||
|
|
||||||
private final double crossoverRate;
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
private TwoParentSelection parentSelection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The probability that the operator generates a crossover (default 0.85).
|
|
||||||
*
|
|
||||||
* @param rate A value between 0.0 and 1.0
|
|
||||||
*/
|
|
||||||
public Builder crossoverRate(double rate) {
|
|
||||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
|
||||||
|
|
||||||
this.crossoverRate = rate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The parent selection behavior. Default is random parent selection.
|
|
||||||
*
|
|
||||||
* @param parentSelection An instance of TwoParentSelection
|
|
||||||
*/
|
|
||||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
|
||||||
this.parentSelection = parentSelection;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ArithmeticCrossover build() {
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parentSelection == null) {
|
|
||||||
parentSelection = new RandomTwoParentSelection();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new ArithmeticCrossover(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private ArithmeticCrossover(ArithmeticCrossover.Builder builder) {
|
|
||||||
super(builder.parentSelection);
|
|
||||||
|
|
||||||
this.crossoverRate = builder.crossoverRate;
|
|
||||||
this.rng = builder.rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Has a probability <i>crossoverRate</i> of performing the crossover where each gene is a linear combination of:<br>
|
|
||||||
* <i>t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.</i><br>
|
|
||||||
* Otherwise, returns the genes of a random parent.
|
|
||||||
*
|
|
||||||
* @return The crossover result. See {@link CrossoverResult}.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public CrossoverResult crossover() {
|
|
||||||
double[][] parents = parentSelection.selectParents();
|
|
||||||
|
|
||||||
double[] offspringValues = new double[parents[0].length];
|
|
||||||
|
|
||||||
if (rng.nextDouble() < crossoverRate) {
|
|
||||||
for (int i = 0; i < offspringValues.length; ++i) {
|
|
||||||
double t = rng.nextDouble();
|
|
||||||
offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i];
|
|
||||||
}
|
|
||||||
return new CrossoverResult(true, offspringValues);
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CrossoverResult(false, parents[0]);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Abstract class for all crossover operators
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class CrossoverOperator {
|
|
||||||
protected PopulationModel populationModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Will be called by the selection operator once the population model is instantiated.
|
|
||||||
*/
|
|
||||||
public void initializeInstance(PopulationModel populationModel) {
|
|
||||||
this.populationModel = populationModel;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs the crossover
|
|
||||||
*
|
|
||||||
* @return The crossover result. See {@link CrossoverResult}.
|
|
||||||
*/
|
|
||||||
public abstract CrossoverResult crossover();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returned by a crossover operator
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class CrossoverResult {
|
|
||||||
/**
|
|
||||||
* If false, there was no crossover and the operator simply returned the genes of a random parent.
|
|
||||||
* If true, the genes are the result of a crossover.
|
|
||||||
*/
|
|
||||||
private final boolean isModified;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The genes returned by the operator.
|
|
||||||
*/
|
|
||||||
private final double[] genes;
|
|
||||||
|
|
||||||
public CrossoverResult(boolean isModified, double[] genes) {
|
|
||||||
this.isModified = isModified;
|
|
||||||
this.genes = genes;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,178 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
import java.util.Deque;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The K-Point crossover will select at random multiple crossover points.<br>
|
|
||||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
|
|
||||||
*/
|
|
||||||
public class KPointCrossover extends TwoParentsCrossoverOperator {
|
|
||||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
|
||||||
private static final int DEFAULT_MIN_CROSSOVER = 1;
|
|
||||||
private static final int DEFAULT_MAX_CROSSOVER = 4;
|
|
||||||
|
|
||||||
private final double crossoverRate;
|
|
||||||
private final int minCrossovers;
|
|
||||||
private final int maxCrossovers;
|
|
||||||
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
|
||||||
private int minCrossovers = DEFAULT_MIN_CROSSOVER;
|
|
||||||
private int maxCrossovers = DEFAULT_MAX_CROSSOVER;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
private TwoParentSelection parentSelection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The probability that the operator generates a crossover (default 0.85).
|
|
||||||
*
|
|
||||||
* @param rate A value between 0.0 and 1.0
|
|
||||||
*/
|
|
||||||
public Builder crossoverRate(double rate) {
|
|
||||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
|
||||||
|
|
||||||
this.crossoverRate = rate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of crossovers points (default is min 1, max 4)
|
|
||||||
*
|
|
||||||
* @param min The minimum number
|
|
||||||
* @param max The maximum number
|
|
||||||
*/
|
|
||||||
public Builder numCrossovers(int min, int max) {
|
|
||||||
Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive");
|
|
||||||
Preconditions.checkState(max >= min, "Max must be greater or equal to min");
|
|
||||||
|
|
||||||
this.minCrossovers = min;
|
|
||||||
this.maxCrossovers = max;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a fixed number of crossover points
|
|
||||||
*
|
|
||||||
* @param num The number of crossovers
|
|
||||||
*/
|
|
||||||
public Builder numCrossovers(int num) {
|
|
||||||
Preconditions.checkState(num >= 0, "Num must be positive");
|
|
||||||
|
|
||||||
this.minCrossovers = num;
|
|
||||||
this.maxCrossovers = num;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The parent selection behavior. Default is random parent selection.
|
|
||||||
*
|
|
||||||
* @param parentSelection An instance of TwoParentSelection
|
|
||||||
*/
|
|
||||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
|
||||||
this.parentSelection = parentSelection;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public KPointCrossover build() {
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parentSelection == null) {
|
|
||||||
parentSelection = new RandomTwoParentSelection();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new KPointCrossover(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private CrossoverPointsGenerator crossoverPointsGenerator;
|
|
||||||
|
|
||||||
private KPointCrossover(KPointCrossover.Builder builder) {
|
|
||||||
super(builder.parentSelection);
|
|
||||||
|
|
||||||
this.crossoverRate = builder.crossoverRate;
|
|
||||||
this.maxCrossovers = builder.maxCrossovers;
|
|
||||||
this.minCrossovers = builder.minCrossovers;
|
|
||||||
this.rng = builder.rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) {
|
|
||||||
if (crossoverPointsGenerator == null) {
|
|
||||||
crossoverPointsGenerator =
|
|
||||||
new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng);
|
|
||||||
}
|
|
||||||
|
|
||||||
return crossoverPointsGenerator;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select at random multiple crossover points.<br>
|
|
||||||
* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. <br>
|
|
||||||
* Otherwise, returns the genes of a random parent.
|
|
||||||
*
|
|
||||||
* @return The crossover result. See {@link CrossoverResult}.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public CrossoverResult crossover() {
|
|
||||||
double[][] parents = parentSelection.selectParents();
|
|
||||||
|
|
||||||
boolean isModified = false;
|
|
||||||
double[] resultGenes = parents[0];
|
|
||||||
|
|
||||||
if (rng.nextDouble() < crossoverRate) {
|
|
||||||
// Select crossover points
|
|
||||||
Deque<Integer> crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints();
|
|
||||||
|
|
||||||
// Crossover
|
|
||||||
resultGenes = new double[parents[0].length];
|
|
||||||
int currentParent = 0;
|
|
||||||
int nextCrossover = crossoverPoints.pop();
|
|
||||||
for (int i = 0; i < resultGenes.length; ++i) {
|
|
||||||
if (i == nextCrossover) {
|
|
||||||
currentParent = currentParent == 0 ? 1 : 0;
|
|
||||||
nextCrossover = crossoverPoints.pop();
|
|
||||||
}
|
|
||||||
resultGenes[i] = parents[currentParent][i];
|
|
||||||
}
|
|
||||||
isModified = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CrossoverResult(isModified, resultGenes);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,123 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The single point crossover will select a random point where every genes before that point comes from one parent
|
|
||||||
* and after which every genes comes from the other parent.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class SinglePointCrossover extends TwoParentsCrossoverOperator {
|
|
||||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
|
||||||
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
private final double crossoverRate;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
private TwoParentSelection parentSelection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The probability that the operator generates a crossover (default 0.85).
|
|
||||||
*
|
|
||||||
* @param rate A value between 0.0 and 1.0
|
|
||||||
*/
|
|
||||||
public Builder crossoverRate(double rate) {
|
|
||||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
|
||||||
|
|
||||||
this.crossoverRate = rate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The parent selection behavior. Default is random parent selection.
|
|
||||||
*
|
|
||||||
* @param parentSelection An instance of TwoParentSelection
|
|
||||||
*/
|
|
||||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
|
||||||
this.parentSelection = parentSelection;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public SinglePointCrossover build() {
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (parentSelection == null) {
|
|
||||||
parentSelection = new RandomTwoParentSelection();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new SinglePointCrossover(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private SinglePointCrossover(SinglePointCrossover.Builder builder) {
|
|
||||||
super(builder.parentSelection);
|
|
||||||
|
|
||||||
this.crossoverRate = builder.crossoverRate;
|
|
||||||
this.rng = builder.rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select a random crossover point.<br>
|
|
||||||
* Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent.
|
|
||||||
* Otherwise, returns the genes of a random parent.
|
|
||||||
*
|
|
||||||
* @return The crossover result. See {@link CrossoverResult}.
|
|
||||||
*/
|
|
||||||
public CrossoverResult crossover() {
|
|
||||||
double[][] parents = parentSelection.selectParents();
|
|
||||||
|
|
||||||
boolean isModified = false;
|
|
||||||
double[] resultGenes = parents[0];
|
|
||||||
|
|
||||||
if (rng.nextDouble() < crossoverRate) {
|
|
||||||
int chromosomeLength = parents[0].length;
|
|
||||||
|
|
||||||
// Crossover
|
|
||||||
resultGenes = new double[chromosomeLength];
|
|
||||||
|
|
||||||
int crossoverPoint = rng.nextInt(chromosomeLength);
|
|
||||||
for (int i = 0; i < resultGenes.length; ++i) {
|
|
||||||
resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i];
|
|
||||||
}
|
|
||||||
isModified = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CrossoverResult(isModified, resultGenes);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Abstract class for all crossover operators that applies to two parents.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class TwoParentsCrossoverOperator extends CrossoverOperator {
|
|
||||||
|
|
||||||
protected final TwoParentSelection parentSelection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param parentSelection A parent selection that selects two parents.
|
|
||||||
*/
|
|
||||||
protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) {
|
|
||||||
this.parentSelection = parentSelection;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Will be called by the selection operator once the population model is instantiated.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void initializeInstance(PopulationModel populationModel) {
|
|
||||||
super.initializeInstance(populationModel);
|
|
||||||
parentSelection.initializeInstance(populationModel.getPopulation());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,136 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The uniform crossover will, for each gene, randomly select the parent that donates the gene.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class UniformCrossover extends TwoParentsCrossoverOperator {
|
|
||||||
private static final double DEFAULT_CROSSOVER_RATE = 0.85;
|
|
||||||
private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5;
|
|
||||||
|
|
||||||
private final double crossoverRate;
|
|
||||||
private final double parentBiasFactor;
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private double crossoverRate = DEFAULT_CROSSOVER_RATE;
|
|
||||||
private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
private TwoParentSelection parentSelection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The probability that the operator generates a crossover (default 0.85).
|
|
||||||
*
|
|
||||||
* @param rate A value between 0.0 and 1.0
|
|
||||||
*/
|
|
||||||
public Builder crossoverRate(double rate) {
|
|
||||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
|
||||||
|
|
||||||
this.crossoverRate = rate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A factor that will introduce a bias in the parent selection.<br>
|
|
||||||
*
|
|
||||||
* @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias.
|
|
||||||
*/
|
|
||||||
public Builder parentBiasFactor(double factor) {
|
|
||||||
Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s",
|
|
||||||
factor);
|
|
||||||
|
|
||||||
this.parentBiasFactor = factor;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The parent selection behavior. Default is random parent selection.
|
|
||||||
*
|
|
||||||
* @param parentSelection An instance of TwoParentSelection
|
|
||||||
*/
|
|
||||||
public Builder parentSelection(TwoParentSelection parentSelection) {
|
|
||||||
this.parentSelection = parentSelection;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public UniformCrossover build() {
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
if (parentSelection == null) {
|
|
||||||
parentSelection = new RandomTwoParentSelection();
|
|
||||||
}
|
|
||||||
return new UniformCrossover(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private UniformCrossover(UniformCrossover.Builder builder) {
|
|
||||||
super(builder.parentSelection);
|
|
||||||
|
|
||||||
this.crossoverRate = builder.crossoverRate;
|
|
||||||
this.parentBiasFactor = builder.parentBiasFactor;
|
|
||||||
this.rng = builder.rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Has a probability <i>crossoverRate</i> of performing the crossover where the operator will select randomly which parent donates the gene.<br>
|
|
||||||
* One of the parent may be favored if the bias is different than 0.5
|
|
||||||
* Otherwise, returns the genes of a random parent.
|
|
||||||
*
|
|
||||||
* @return The crossover result. See {@link CrossoverResult}.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public CrossoverResult crossover() {
|
|
||||||
// select the parents
|
|
||||||
double[][] parents = parentSelection.selectParents();
|
|
||||||
|
|
||||||
double[] resultGenes = parents[0];
|
|
||||||
boolean isModified = false;
|
|
||||||
|
|
||||||
if (rng.nextDouble() < crossoverRate) {
|
|
||||||
// Crossover
|
|
||||||
resultGenes = new double[parents[0].length];
|
|
||||||
|
|
||||||
for (int i = 0; i < resultGenes.length; ++i) {
|
|
||||||
resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i];
|
|
||||||
}
|
|
||||||
isModified = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new CrossoverResult(isModified, resultGenes);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,44 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Abstract class for all parent selection behaviors
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class ParentSelection {
|
|
||||||
protected List<Chromosome> population;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Will be called by the crossover operator once the population model is instantiated.
|
|
||||||
*/
|
|
||||||
public void initializeInstance(List<Chromosome> population) {
|
|
||||||
this.population = population;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs the parent selection
|
|
||||||
*
|
|
||||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
|
||||||
*/
|
|
||||||
public abstract double[][] selectParents();
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A parent selection behavior that returns two random parents.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class RandomTwoParentSelection extends TwoParentSelection {
|
|
||||||
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
|
|
||||||
public RandomTwoParentSelection() {
|
|
||||||
this(new SynchronizedRandomGenerator(new JDKRandomGenerator()));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public RandomTwoParentSelection(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Selects two random parents
|
|
||||||
*
|
|
||||||
* @return An array of parents genes. The outer array are the parents, and the inner array are the genes.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public double[][] selectParents() {
|
|
||||||
double[][] parents = new double[2][];
|
|
||||||
|
|
||||||
int parent1Idx = rng.nextInt(population.size());
|
|
||||||
int parent2Idx;
|
|
||||||
do {
|
|
||||||
parent2Idx = rng.nextInt(population.size());
|
|
||||||
} while (parent1Idx == parent2Idx);
|
|
||||||
|
|
||||||
parents[0] = population.get(parent1Idx).getGenes();
|
|
||||||
parents[1] = population.get(parent2Idx).getGenes();
|
|
||||||
|
|
||||||
return parents;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Abstract class for all parent selection behaviors that selects two parents.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class TwoParentSelection extends ParentSelection {
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A helper class used by {@link KPointCrossover} to generate the crossover points
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class CrossoverPointsGenerator {
|
|
||||||
private final int minCrossovers;
|
|
||||||
private final int maxCrossovers;
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
private List<Integer> parameterIndexes;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Constructor
|
|
||||||
*
|
|
||||||
* @param chromosomeLength The number of genes
|
|
||||||
* @param minCrossovers The minimum number of crossover points to generate
|
|
||||||
* @param maxCrossovers The maximum number of crossover points to generate
|
|
||||||
* @param rng A RandomGenerator instance
|
|
||||||
*/
|
|
||||||
public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) {
|
|
||||||
this.minCrossovers = minCrossovers;
|
|
||||||
this.maxCrossovers = maxCrossovers;
|
|
||||||
this.rng = rng;
|
|
||||||
parameterIndexes = new ArrayList<Integer>();
|
|
||||||
for (int i = 0; i < chromosomeLength; ++i) {
|
|
||||||
parameterIndexes.add(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a list of crossover points.
|
|
||||||
*
|
|
||||||
* @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element
|
|
||||||
*/
|
|
||||||
public Deque<Integer> getCrossoverPoints() {
|
|
||||||
Collections.shuffle(parameterIndexes);
|
|
||||||
List<Integer> crossoverPointLists =
|
|
||||||
parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers);
|
|
||||||
Collections.sort(crossoverPointLists);
|
|
||||||
Deque<Integer> crossoverPoints = new ArrayDeque<Integer>(crossoverPointLists);
|
|
||||||
crossoverPoints.add(Integer.MAX_VALUE);
|
|
||||||
|
|
||||||
return crossoverPoints;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The cull operator will remove from the population the least desirables chromosomes.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public interface CullOperator {
|
|
||||||
/**
|
|
||||||
* Will be called by the population model once created.
|
|
||||||
*/
|
|
||||||
void initializeInstance(PopulationModel populationModel);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cull the population to the culled size.
|
|
||||||
*/
|
|
||||||
void cullPopulation();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The target population size after culling.
|
|
||||||
*/
|
|
||||||
int getCulledSize();
|
|
||||||
}
|
|
|
@ -1,50 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class LeastFitCullOperator extends RatioCullOperator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The default cull ratio is 1/3.
|
|
||||||
*/
|
|
||||||
public LeastFitCullOperator() {
|
|
||||||
super();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
|
||||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
|
||||||
*/
|
|
||||||
public LeastFitCullOperator(double cullRatio) {
|
|
||||||
super(cullRatio);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Will discard the chromosomes with the worst fitness until the population size fall back at the culled size.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void cullPopulation() {
|
|
||||||
while (population.size() > culledSize) {
|
|
||||||
population.remove(population.size() - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,70 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.culling;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An abstract base for cull operators that culls back the population to a ratio of its maximum size.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class RatioCullOperator implements CullOperator {
|
|
||||||
private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0;
|
|
||||||
protected int culledSize;
|
|
||||||
protected List<Chromosome> population;
|
|
||||||
protected final double cullRatio;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param cullRatio The ratio of the maximum population size to be culled.<br>
|
|
||||||
* For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20.
|
|
||||||
*/
|
|
||||||
public RatioCullOperator(double cullRatio) {
|
|
||||||
Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s",
|
|
||||||
cullRatio);
|
|
||||||
|
|
||||||
this.cullRatio = cullRatio;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The default cull ratio is 1/3
|
|
||||||
*/
|
|
||||||
public RatioCullOperator() {
|
|
||||||
this(DEFAULT_CULL_RATIO);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Will be called by the population model once created.
|
|
||||||
*/
|
|
||||||
public void initializeInstance(PopulationModel populationModel) {
|
|
||||||
this.population = populationModel.getPopulation();
|
|
||||||
culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return The target population size after culling.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public int getCulledSize() {
|
|
||||||
return culledSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,23 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions;
|
|
||||||
|
|
||||||
public class GeneticGenerationException extends RuntimeException {
|
|
||||||
public GeneticGenerationException(String message) {
|
|
||||||
super(message);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,33 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The mutation operator will apply a mutation to the given genes.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public interface MutationOperator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs a mutation.
|
|
||||||
*
|
|
||||||
* @param genes The genes to be mutated
|
|
||||||
* @return True if the genes were mutated, otherwise false.
|
|
||||||
*/
|
|
||||||
boolean mutate(double[] genes);
|
|
||||||
}
|
|
|
@ -1,93 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A mutation operator where each gene has a chance of being mutated with a <i>mutation rate</i> probability.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class RandomMutationOperator implements MutationOperator {
|
|
||||||
private static final double DEFAULT_MUTATION_RATE = 0.005;
|
|
||||||
|
|
||||||
private final double mutationRate;
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private double mutationRate = DEFAULT_MUTATION_RATE;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Each gene will have this probability of being mutated.
|
|
||||||
*
|
|
||||||
* @param rate The mutation rate. (default 0.005)
|
|
||||||
*/
|
|
||||||
public Builder mutationRate(double rate) {
|
|
||||||
Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate);
|
|
||||||
|
|
||||||
this.mutationRate = rate;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RandomMutationOperator build() {
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
return new RandomMutationOperator(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private RandomMutationOperator(RandomMutationOperator.Builder builder) {
|
|
||||||
this.mutationRate = builder.mutationRate;
|
|
||||||
this.rng = builder.rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Performs the mutation. Each gene has a <i>mutation rate</i> probability of being mutated.
|
|
||||||
*
|
|
||||||
* @param genes The genes to be mutated
|
|
||||||
* @return True if the genes were mutated, otherwise false.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public boolean mutate(double[] genes) {
|
|
||||||
boolean hasMutation = false;
|
|
||||||
|
|
||||||
for (int i = 0; i < genes.length; ++i) {
|
|
||||||
if (rng.nextDouble() < mutationRate) {
|
|
||||||
genes[i] = rng.nextDouble();
|
|
||||||
hasMutation = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasMutation;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A population initializer that build an empty population.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class EmptyPopulationInitializer implements PopulationInitializer {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize an empty population
|
|
||||||
*
|
|
||||||
* @param size The maximum size of the population.
|
|
||||||
* @return The initialized population.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public List<Chromosome> getInitializedPopulation(int size) {
|
|
||||||
return new ArrayList<>(size);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,36 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An initializer that construct the population used by the population model.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public interface PopulationInitializer {
|
|
||||||
/**
|
|
||||||
* Called by the population model to construct the population
|
|
||||||
*
|
|
||||||
* @param size The maximum size of the population
|
|
||||||
* @return An initialized population
|
|
||||||
*/
|
|
||||||
List<Chromosome> getInitializedPopulation(int size);
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A listener that is called when the population changes.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public interface PopulationListener {
|
|
||||||
/**
|
|
||||||
* Called after the population has changed.
|
|
||||||
*
|
|
||||||
* @param population The population after it has changed.
|
|
||||||
*/
|
|
||||||
void onChanged(List<Chromosome> population);
|
|
||||||
}
|
|
|
@ -1,182 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.population;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The population model handles all aspects of the population (initialization, additions and culling)
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class PopulationModel {
|
|
||||||
private static final int DEFAULT_POPULATION_SIZE = 30;
|
|
||||||
|
|
||||||
private final CullOperator cullOperator;
|
|
||||||
private final List<PopulationListener> populationListeners = new ArrayList<>();
|
|
||||||
private Comparator<Chromosome> chromosomeComparator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The maximum population size
|
|
||||||
*/
|
|
||||||
@Getter
|
|
||||||
private final int populationSize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The population
|
|
||||||
*/
|
|
||||||
@Getter
|
|
||||||
public final List<Chromosome> population;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A comparator used when higher fitness value is better
|
|
||||||
*/
|
|
||||||
public static class MaximizeScoreComparator implements Comparator<Chromosome> {
|
|
||||||
@Override
|
|
||||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
|
||||||
return -Double.compare(lhs.getFitness(), rhs.getFitness());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A comparator used when lower fitness value is better
|
|
||||||
*/
|
|
||||||
public static class MinimizeScoreComparator implements Comparator<Chromosome> {
|
|
||||||
@Override
|
|
||||||
public int compare(Chromosome lhs, Chromosome rhs) {
|
|
||||||
return Double.compare(lhs.getFitness(), rhs.getFitness());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private int populationSize = DEFAULT_POPULATION_SIZE;
|
|
||||||
private PopulationInitializer populationInitializer;
|
|
||||||
private CullOperator cullOperator;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use an alternate population initialization behavior. Default is empty population.
|
|
||||||
*
|
|
||||||
* @param populationInitializer An instance of PopulationInitializer
|
|
||||||
*/
|
|
||||||
public Builder populationInitializer(PopulationInitializer populationInitializer) {
|
|
||||||
this.populationInitializer = populationInitializer;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The maximum population size. <br>
|
|
||||||
* If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results.
|
|
||||||
* (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up)
|
|
||||||
*
|
|
||||||
* @param size The maximum size of the population
|
|
||||||
*/
|
|
||||||
public Builder populationSize(int size) {
|
|
||||||
populationSize = size;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use an alternate cull operator behavior. Default is least fit culling.
|
|
||||||
*
|
|
||||||
* @param cullOperator An instance of a CullOperator
|
|
||||||
*/
|
|
||||||
public Builder cullOperator(CullOperator cullOperator) {
|
|
||||||
this.cullOperator = cullOperator;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public PopulationModel build() {
|
|
||||||
if (cullOperator == null) {
|
|
||||||
cullOperator = new LeastFitCullOperator();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (populationInitializer == null) {
|
|
||||||
populationInitializer = new EmptyPopulationInitializer();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new PopulationModel(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public PopulationModel(PopulationModel.Builder builder) {
|
|
||||||
populationSize = builder.populationSize;
|
|
||||||
population = new ArrayList<>(builder.populationSize);
|
|
||||||
PopulationInitializer populationInitializer = builder.populationInitializer;
|
|
||||||
|
|
||||||
List<Chromosome> initializedPopulation = populationInitializer.getInitializedPopulation(populationSize);
|
|
||||||
population.clear();
|
|
||||||
population.addAll(initializedPopulation);
|
|
||||||
|
|
||||||
cullOperator = builder.cullOperator;
|
|
||||||
cullOperator.initializeInstance(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called by the GeneticSearchCandidateGenerator
|
|
||||||
*/
|
|
||||||
public void initializeInstance(boolean minimizeScore) {
|
|
||||||
chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a PopulationListener to the list of change listeners
|
|
||||||
* @param listener A PopulationListener instance
|
|
||||||
*/
|
|
||||||
public void addListener(PopulationListener listener) {
|
|
||||||
populationListeners.add(listener);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered.
|
|
||||||
*
|
|
||||||
* @param element The chromosome to be added
|
|
||||||
*/
|
|
||||||
public void add(Chromosome element) {
|
|
||||||
if (population.size() == populationSize) {
|
|
||||||
cullOperator.cullPopulation();
|
|
||||||
}
|
|
||||||
|
|
||||||
population.add(element);
|
|
||||||
|
|
||||||
Collections.sort(population, chromosomeComparator);
|
|
||||||
|
|
||||||
triggerPopulationChangedListeners(population);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return Return false when the population is below the culled size, otherwise true. <br>
|
|
||||||
* Used by the selection operator to know if the population is still too small and should generate random genes.
|
|
||||||
*/
|
|
||||||
public boolean isReadyToBreed() {
|
|
||||||
return population.size() >= cullOperator.getCulledSize();
|
|
||||||
}
|
|
||||||
|
|
||||||
private void triggerPopulationChangedListeners(List<Chromosome> population) {
|
|
||||||
for (PopulationListener listener : populationListeners) {
|
|
||||||
listener.onChanged(population);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,197 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.random.JDKRandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.RandomGenerator;
|
|
||||||
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A selection operator that will generate random genes initially. Once the population has reached the culled size,
|
|
||||||
* will start to generate offsprings of parents selected in the population.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public class GeneticSelectionOperator extends SelectionOperator {
|
|
||||||
|
|
||||||
private final static int PREVIOUS_GENES_TO_KEEP = 100;
|
|
||||||
private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024;
|
|
||||||
|
|
||||||
private final CrossoverOperator crossoverOperator;
|
|
||||||
private final MutationOperator mutationOperator;
|
|
||||||
private final RandomGenerator rng;
|
|
||||||
private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][];
|
|
||||||
private int previousGenesIdx = 0;
|
|
||||||
|
|
||||||
public static class Builder {
|
|
||||||
private ChromosomeFactory chromosomeFactory;
|
|
||||||
private PopulationModel populationModel;
|
|
||||||
private CrossoverOperator crossoverOperator;
|
|
||||||
private MutationOperator mutationOperator;
|
|
||||||
private RandomGenerator rng;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use an alternate crossover behavior. Default is SinglePointCrossover.
|
|
||||||
*
|
|
||||||
* @param crossoverOperator An instance of CrossoverOperator
|
|
||||||
*/
|
|
||||||
public Builder crossoverOperator(CrossoverOperator crossoverOperator) {
|
|
||||||
this.crossoverOperator = crossoverOperator;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use an alternate mutation behavior. Default is RandomMutationOperator.
|
|
||||||
*
|
|
||||||
* @param mutationOperator An instance of MutationOperator
|
|
||||||
*/
|
|
||||||
public Builder mutationOperator(MutationOperator mutationOperator) {
|
|
||||||
this.mutationOperator = mutationOperator;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Use a supplied RandomGenerator
|
|
||||||
*
|
|
||||||
* @param rng An instance of RandomGenerator
|
|
||||||
*/
|
|
||||||
public Builder randomGenerator(RandomGenerator rng) {
|
|
||||||
this.rng = rng;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GeneticSelectionOperator build() {
|
|
||||||
if (crossoverOperator == null) {
|
|
||||||
crossoverOperator = new SinglePointCrossover.Builder().build();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mutationOperator == null) {
|
|
||||||
mutationOperator = new RandomMutationOperator.Builder().build();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (rng == null) {
|
|
||||||
rng = new SynchronizedRandomGenerator(new JDKRandomGenerator());
|
|
||||||
}
|
|
||||||
|
|
||||||
return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator,
|
|
||||||
RandomGenerator rng) {
|
|
||||||
this.crossoverOperator = crossoverOperator;
|
|
||||||
this.mutationOperator = mutationOperator;
|
|
||||||
this.rng = rng;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called by GeneticSearchCandidateGenerator
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
|
||||||
super.initializeInstance(populationModel, chromosomeFactory);
|
|
||||||
crossoverOperator.initializeInstance(populationModel);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build a new set of genes. Has two distinct modes of operation
|
|
||||||
* <ul>
|
|
||||||
* <li>Before the population has reached the culled size: will return a random set of genes.</li>
|
|
||||||
* <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
|
|
||||||
* </ul>
|
|
||||||
* @return Returns the generated set of genes
|
|
||||||
* @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried,
|
|
||||||
* or if the crossover and the mutation operators can't generate a set,
|
|
||||||
* this exception is thrown.
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public double[] buildNextGenes() {
|
|
||||||
double[] result;
|
|
||||||
|
|
||||||
boolean hasAlreadyBeenTried;
|
|
||||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
|
||||||
do {
|
|
||||||
if (populationModel.isReadyToBreed()) {
|
|
||||||
result = buildOffspring();
|
|
||||||
} else {
|
|
||||||
result = buildRandomGenes();
|
|
||||||
}
|
|
||||||
|
|
||||||
hasAlreadyBeenTried = hasAlreadyBeenTried(result);
|
|
||||||
if (hasAlreadyBeenTried && --attemptsRemaining == 0) {
|
|
||||||
throw new GeneticGenerationException("Failed to generate a set of genes not already tried.");
|
|
||||||
}
|
|
||||||
} while (hasAlreadyBeenTried);
|
|
||||||
|
|
||||||
previousGenes[previousGenesIdx] = result;
|
|
||||||
previousGenesIdx = ++previousGenesIdx % previousGenes.length;
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean hasAlreadyBeenTried(double[] genes) {
|
|
||||||
for (int i = 0; i < previousGenes.length; ++i) {
|
|
||||||
double[] current = previousGenes[i];
|
|
||||||
if (current != null && Arrays.equals(current, genes)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private double[] buildOffspring() {
|
|
||||||
double[] offspringValues;
|
|
||||||
|
|
||||||
boolean isModified;
|
|
||||||
int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS;
|
|
||||||
do {
|
|
||||||
CrossoverResult crossoverResult = crossoverOperator.crossover();
|
|
||||||
offspringValues = crossoverResult.getGenes();
|
|
||||||
isModified = crossoverResult.isModified();
|
|
||||||
isModified |= mutationOperator.mutate(offspringValues);
|
|
||||||
|
|
||||||
if (!isModified && --attemptsRemaining == 0) {
|
|
||||||
throw new GeneticGenerationException(
|
|
||||||
String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.",
|
|
||||||
MAX_NUM_GENERATION_ATTEMPTS));
|
|
||||||
}
|
|
||||||
} while (!isModified);
|
|
||||||
|
|
||||||
return offspringValues;
|
|
||||||
}
|
|
||||||
|
|
||||||
private double[] buildRandomGenes() {
|
|
||||||
double[] randomValues = new double[chromosomeFactory.getChromosomeLength()];
|
|
||||||
for (int i = 0; i < randomValues.length; ++i) {
|
|
||||||
randomValues[i] = rng.nextDouble();
|
|
||||||
}
|
|
||||||
|
|
||||||
return randomValues;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,44 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.genetic.selection;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates.
|
|
||||||
*
|
|
||||||
* @author Alexandre Boulanger
|
|
||||||
*/
|
|
||||||
public abstract class SelectionOperator {
|
|
||||||
protected PopulationModel populationModel;
|
|
||||||
protected ChromosomeFactory chromosomeFactory;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called by GeneticSearchCandidateGenerator
|
|
||||||
*/
|
|
||||||
public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) {
|
|
||||||
|
|
||||||
this.populationModel = populationModel;
|
|
||||||
this.chromosomeFactory = chromosomeFactory;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a new set of genes.
|
|
||||||
*/
|
|
||||||
public abstract double[] buildNextGenes();
|
|
||||||
}
|
|
|
@ -1,46 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.generator.util;
|
|
||||||
|
|
||||||
import org.nd4j.common.function.Supplier;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
|
|
||||||
public class SerializedSupplier<T> implements Serializable, Supplier<T> {
|
|
||||||
|
|
||||||
private byte[] asBytes;
|
|
||||||
|
|
||||||
public SerializedSupplier(T obj){
|
|
||||||
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
|
||||||
oos.writeObject(obj);
|
|
||||||
oos.flush();
|
|
||||||
oos.close();
|
|
||||||
asBytes = baos.toByteArray();
|
|
||||||
} catch (Exception e){
|
|
||||||
throw new RuntimeException("Error serializing object - must be serializable",e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T get() {
|
|
||||||
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))){
|
|
||||||
return (T)ois.readObject();
|
|
||||||
} catch (Exception e){
|
|
||||||
throw new RuntimeException("Error deserializing object",e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,76 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BooleanParameterSpace is a {@code ParameterSpace<Boolean>}; Defines {True, False} as a parameter space
|
|
||||||
* If argument to setValue is less than or equal to 0.5 it will return True else False
|
|
||||||
*
|
|
||||||
* @author susaneraly
|
|
||||||
*/
|
|
||||||
@EqualsAndHashCode
|
|
||||||
public class BooleanSpace implements ParameterSpace<Boolean> {
|
|
||||||
private int index = -1;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Boolean getValue(double[] input) {
|
|
||||||
if (index == -1) {
|
|
||||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
|
||||||
}
|
|
||||||
if (input[index] <= 0.5) return Boolean.TRUE;
|
|
||||||
else return Boolean.FALSE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return Collections.singletonList((ParameterSpace) this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.emptyMap();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
if (indices == null || indices.length != 1)
|
|
||||||
throw new IllegalArgumentException("Invalid index");
|
|
||||||
this.index = indices[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "BooleanSpace()";
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,90 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter;
|
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.Getter;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer;
|
|
||||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* FixedValue is a ParameterSpace that defines only a single fixed value
|
|
||||||
*
|
|
||||||
* @param <T> Type of (fixed) value
|
|
||||||
*/
|
|
||||||
@EqualsAndHashCode
|
|
||||||
@JsonSerialize(using = FixedValueSerializer.class)
|
|
||||||
@JsonDeserialize(using = FixedValueDeserializer.class)
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public class FixedValue<T> implements ParameterSpace<T> {
|
|
||||||
@Getter
|
|
||||||
private Object value;
|
|
||||||
private int index;
|
|
||||||
|
|
||||||
@JsonCreator
|
|
||||||
public FixedValue(@JsonProperty("value") T value) {
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return "FixedValue(" + ObjectUtils.valueToString(value) + ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T getValue(double[] input) {
|
|
||||||
return (T) value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return Collections.emptyList();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.emptyMap();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
if (indices != null && indices.length != 0)
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Invalid call: FixedValue ParameterSpace " + "should not be given an index (0 params)");
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,135 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.continuous;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.RealDistribution;
|
|
||||||
import org.apache.commons.math3.distribution.UniformRealDistribution;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ContinuousParametSpace is a {@code ParameterSpace<Double>} that (optionally) takes an Apache Commons
|
|
||||||
* {@link RealDistribution} when used for random sampling (such as in a RandomSearchCandidateGenerator)
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class ContinuousParameterSpace implements ParameterSpace<Double> {
|
|
||||||
|
|
||||||
//Need to use custom serializers/deserializers for commons RealDistribution instances
|
|
||||||
@JsonSerialize(using = RealDistributionSerializer.class)
|
|
||||||
@JsonDeserialize(using = RealDistributionDeserializer.class)
|
|
||||||
private RealDistribution distribution;
|
|
||||||
private int index = -1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ContinuousParameterSpace with uniform distribution between the minimum and maximum values
|
|
||||||
*
|
|
||||||
* @param min Minimum value that can be generated
|
|
||||||
* @param max Maximum value that can be generated
|
|
||||||
*/
|
|
||||||
public ContinuousParameterSpace(double min, double max) {
|
|
||||||
this(new UniformRealDistribution(min, max));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* ConditiousParameterSpcae wiht a specified probability distribution. The provided distribution defines the min/max
|
|
||||||
* values, and (for random search, etc) will be used when generating random values
|
|
||||||
*
|
|
||||||
* @param distribution Distribution to sample from
|
|
||||||
*/
|
|
||||||
public ContinuousParameterSpace(@JsonProperty("distribution") RealDistribution distribution) {
|
|
||||||
this.distribution = distribution;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Double getValue(double[] input) {
|
|
||||||
if (index == -1) {
|
|
||||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
|
||||||
}
|
|
||||||
return distribution.inverseCumulativeProbability(input[index]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return Collections.singletonList((ParameterSpace) this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.emptyMap();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
if (indices == null || indices.length != 1) {
|
|
||||||
throw new IllegalArgumentException("Invalid index");
|
|
||||||
}
|
|
||||||
this.index = indices[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
if (distribution instanceof UniformRealDistribution) {
|
|
||||||
return "ContinuousParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
|
|
||||||
+ distribution.getSupportUpperBound() + ")";
|
|
||||||
} else {
|
|
||||||
return "ContinuousParameterSpace(" + distribution + ")";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (o == this)
|
|
||||||
return true;
|
|
||||||
if (!(o instanceof ContinuousParameterSpace))
|
|
||||||
return false;
|
|
||||||
final ContinuousParameterSpace other = (ContinuousParameterSpace) o;
|
|
||||||
if (distribution == null ? other.distribution != null
|
|
||||||
: !DistributionUtils.distributionsEqual(distribution, other.distribution))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return this.index == other.index;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int hashCode() {
|
|
||||||
final int PRIME = 59;
|
|
||||||
int result = 1;
|
|
||||||
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
|
|
||||||
result = result * PRIME + this.index;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,112 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.discrete;
|
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A DiscreteParameterSpace is used for a set of un-ordered values
|
|
||||||
*
|
|
||||||
* @param <P> Parameter type
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@EqualsAndHashCode
|
|
||||||
public class DiscreteParameterSpace<P> implements ParameterSpace<P> {
|
|
||||||
|
|
||||||
@JsonSerialize
|
|
||||||
private List<P> values;
|
|
||||||
private int index = -1;
|
|
||||||
|
|
||||||
public DiscreteParameterSpace(@JsonProperty("values") P... values) {
|
|
||||||
if (values != null)
|
|
||||||
this.values = Arrays.asList(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
public DiscreteParameterSpace(Collection<P> values) {
|
|
||||||
this.values = new ArrayList<>(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
public int numValues() {
|
|
||||||
return values.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public P getValue(double[] input) {
|
|
||||||
if (index == -1) {
|
|
||||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
|
||||||
}
|
|
||||||
if (values == null)
|
|
||||||
throw new IllegalStateException("Values are null.");
|
|
||||||
//Map a value in range [0,1] to one of the list of values
|
|
||||||
//First value: [0,width], second: (width,2*width], third: (3*width,4*width] etc
|
|
||||||
int size = values.size();
|
|
||||||
if (size == 1)
|
|
||||||
return values.get(0);
|
|
||||||
double width = 1.0 / size;
|
|
||||||
int val = (int) (input[index] / width);
|
|
||||||
return values.get(Math.min(val, size - 1));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return Collections.singletonList((ParameterSpace) this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.emptyMap();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
if (indices == null || indices.length != 1) {
|
|
||||||
throw new IllegalArgumentException("Invalid index");
|
|
||||||
}
|
|
||||||
this.index = indices[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
StringBuilder sb = new StringBuilder();
|
|
||||||
sb.append("DiscreteParameterSpace(");
|
|
||||||
int n = values.size();
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
P value = values.get(i);
|
|
||||||
sb.append(ObjectUtils.valueToString(value));
|
|
||||||
sb.append((i == n - 1 ? ")" : ","));
|
|
||||||
}
|
|
||||||
return sb.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,149 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.integer;
|
|
||||||
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
|
||||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionDeserializer;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionSerializer;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonCreator;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonProperty;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
|
|
||||||
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* IntegerParameterSpace is a {@code ParameterSpace<Integer>}; i.e., defines an ordered space of integers between
|
|
||||||
* some minimum and maximum value
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@JsonIgnoreProperties({"min", "max"})
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class IntegerParameterSpace implements ParameterSpace<Integer> {
|
|
||||||
|
|
||||||
@JsonSerialize(using = IntegerDistributionSerializer.class)
|
|
||||||
@JsonDeserialize(using = IntegerDistributionDeserializer.class)
|
|
||||||
private IntegerDistribution distribution;
|
|
||||||
private int index = -1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create an IntegerParameterSpace with a uniform distribution between the specified min/max (inclusive)
|
|
||||||
*
|
|
||||||
* @param min Min value, inclusive
|
|
||||||
* @param max Max value, inclusive
|
|
||||||
*/
|
|
||||||
public IntegerParameterSpace(int min, int max) {
|
|
||||||
this(new UniformIntegerDistribution(min, max));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Crate an IntegerParametSpace from the given IntegerDistribution
|
|
||||||
*
|
|
||||||
* @param distribution Distribution to use
|
|
||||||
*/
|
|
||||||
@JsonCreator
|
|
||||||
public IntegerParameterSpace(@JsonProperty("distribution") IntegerDistribution distribution) {
|
|
||||||
this.distribution = distribution;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getMin() {
|
|
||||||
return distribution.getSupportLowerBound();
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getMax() {
|
|
||||||
return distribution.getSupportUpperBound();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Integer getValue(double[] input) {
|
|
||||||
if (index == -1) {
|
|
||||||
throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set");
|
|
||||||
}
|
|
||||||
return distribution.inverseCumulativeProbability(input[index]);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return Collections.singletonList((ParameterSpace) this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<String, ParameterSpace> getNestedSpaces() {
|
|
||||||
return Collections.emptyMap();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
if (indices == null || indices.length != 1)
|
|
||||||
throw new IllegalArgumentException("Invalid index");
|
|
||||||
this.index = indices[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
if (distribution instanceof UniformIntegerDistribution) {
|
|
||||||
return "IntegerParameterSpace(min=" + distribution.getSupportLowerBound() + ",max="
|
|
||||||
+ distribution.getSupportUpperBound() + ")";
|
|
||||||
} else {
|
|
||||||
return "IntegerParameterSpace(" + distribution + ")";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean equals(Object o) {
|
|
||||||
if (o == this)
|
|
||||||
return true;
|
|
||||||
if (!(o instanceof IntegerParameterSpace))
|
|
||||||
return false;
|
|
||||||
final IntegerParameterSpace other = (IntegerParameterSpace) o;
|
|
||||||
if (!other.canEqual(this))
|
|
||||||
return false;
|
|
||||||
if (distribution == null ? other.distribution != null
|
|
||||||
: !DistributionUtils.distributionEquals(distribution, other.distribution))
|
|
||||||
return false;
|
|
||||||
return this.index == other.index;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int hashCode() {
|
|
||||||
final int PRIME = 59;
|
|
||||||
int result = 1;
|
|
||||||
result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode());
|
|
||||||
result = result * PRIME + this.index;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected boolean canEqual(Object other) {
|
|
||||||
return other instanceof IntegerParameterSpace;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple parameter space that implements scalar mathematical operations on another parameter space. This allows you
|
|
||||||
* to do things like Y = X * 2, where X is a parameter space. For example, a layer size hyperparameter could be set
|
|
||||||
* using this to 2x the size of the previous layer
|
|
||||||
*
|
|
||||||
* @param <T> Type of the parameter space
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class MathOp<T extends Number> extends AbstractParameterSpace<T> {
|
|
||||||
|
|
||||||
private ParameterSpace<T> parameterSpace;
|
|
||||||
private Op op;
|
|
||||||
private T scalar;
|
|
||||||
|
|
||||||
public MathOp(ParameterSpace<T> parameterSpace, Op op, T scalar){
|
|
||||||
this.parameterSpace = parameterSpace;
|
|
||||||
this.op = op;
|
|
||||||
this.scalar = scalar;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T getValue(double[] parameterValues) {
|
|
||||||
T u = parameterSpace.getValue(parameterValues);
|
|
||||||
return op.doOp(u, scalar);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return parameterSpace.numParameters();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
return parameterSpace.collectLeaves();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
parameterSpace.setIndices(indices);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,76 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
|
||||||
|
|
||||||
public enum Op {
|
|
||||||
ADD, SUB, MUL, DIV;
|
|
||||||
|
|
||||||
|
|
||||||
//Package private
|
|
||||||
<T extends Number> T doOp(T first, T second){
|
|
||||||
if(first instanceof Integer || first instanceof Long){
|
|
||||||
long result;
|
|
||||||
switch (this){
|
|
||||||
case ADD:
|
|
||||||
result = Long.valueOf(first.longValue() + second.longValue());
|
|
||||||
break;
|
|
||||||
case SUB:
|
|
||||||
result = Long.valueOf(first.longValue() - second.longValue());
|
|
||||||
break;
|
|
||||||
case MUL:
|
|
||||||
result = Long.valueOf(first.longValue() * second.longValue());
|
|
||||||
break;
|
|
||||||
case DIV:
|
|
||||||
result = Long.valueOf(first.longValue() / second.longValue());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException("Unknown op: " + this);
|
|
||||||
}
|
|
||||||
if(first instanceof Long){
|
|
||||||
return (T)Long.valueOf(result);
|
|
||||||
} else {
|
|
||||||
return (T)Integer.valueOf((int)result);
|
|
||||||
}
|
|
||||||
} else if(first instanceof Double || first instanceof Float){
|
|
||||||
double result;
|
|
||||||
switch (this){
|
|
||||||
case ADD:
|
|
||||||
result = Double.valueOf(first.doubleValue() + second.doubleValue());
|
|
||||||
break;
|
|
||||||
case SUB:
|
|
||||||
result = Double.valueOf(first.doubleValue() - second.doubleValue());
|
|
||||||
break;
|
|
||||||
case MUL:
|
|
||||||
result = Double.valueOf(first.doubleValue() * second.doubleValue());
|
|
||||||
break;
|
|
||||||
case DIV:
|
|
||||||
result = Double.valueOf(first.doubleValue() / second.doubleValue());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw new UnsupportedOperationException("Unknown op: " + this);
|
|
||||||
}
|
|
||||||
if(first instanceof Double){
|
|
||||||
return (T)Double.valueOf(result);
|
|
||||||
} else {
|
|
||||||
return (T)Float.valueOf((float)result);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Not supported type: only Integer, Long, Double, Float supported" +
|
|
||||||
" here. Got type: " + first.getClass());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,79 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.parameter.math;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A simple parameter space that implements pairwise mathematical operations on another parameter space. This allows you
|
|
||||||
* to do things like Z = X + Y, where X and Y are parameter spaces.
|
|
||||||
*
|
|
||||||
* @param <T> Type of the parameter space
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class PairMathOp<T extends Number> extends AbstractParameterSpace<T> {
|
|
||||||
|
|
||||||
private ParameterSpace<T> first;
|
|
||||||
private ParameterSpace<T> second;
|
|
||||||
private Op op;
|
|
||||||
|
|
||||||
public PairMathOp(ParameterSpace<T> first, ParameterSpace<T> second, Op op){
|
|
||||||
this.first = first;
|
|
||||||
this.second = second;
|
|
||||||
this.op = op;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T getValue(double[] parameterValues) {
|
|
||||||
T f = first.getValue(parameterValues);
|
|
||||||
T s = second.getValue(parameterValues);
|
|
||||||
return op.doOp(f, s);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return first.numParameters() + second.numParameters();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
List<ParameterSpace> l = new ArrayList<>();
|
|
||||||
l.addAll(first.collectLeaves());
|
|
||||||
l.addAll(second.collectLeaves());
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
int n1 = first.numParameters();
|
|
||||||
int n2 = second.numParameters();
|
|
||||||
int[] s1 = Arrays.copyOfRange(indices, 0, n1);
|
|
||||||
int[] s2 = Arrays.copyOfRange(indices, n1, n1+n2);
|
|
||||||
first.setIndices(s1);
|
|
||||||
second.setIndices(s2);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,379 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.exception.ExceptionUtils;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.Candidate;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.*;
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BaseOptimization runner: responsible for scheduling tasks, saving results using the result saver, etc.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public abstract class BaseOptimizationRunner implements IOptimizationRunner {
|
|
||||||
private static final int POLLING_FREQUENCY = 1;
|
|
||||||
private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS;
|
|
||||||
|
|
||||||
protected OptimizationConfiguration config;
|
|
||||||
protected Queue<Future<OptimizationResult>> queuedFutures = new ConcurrentLinkedQueue<>();
|
|
||||||
protected BlockingQueue<Future<OptimizationResult>> completedFutures = new LinkedBlockingQueue<>();
|
|
||||||
protected AtomicInteger totalCandidateCount = new AtomicInteger();
|
|
||||||
protected AtomicInteger numCandidatesCompleted = new AtomicInteger();
|
|
||||||
protected AtomicInteger numCandidatesFailed = new AtomicInteger();
|
|
||||||
protected Double bestScore = null;
|
|
||||||
protected Long bestScoreTime = null;
|
|
||||||
protected AtomicInteger bestScoreCandidateIndex = new AtomicInteger(-1);
|
|
||||||
protected List<ResultReference> allResults = new ArrayList<>();
|
|
||||||
|
|
||||||
protected Map<Integer, CandidateInfo> currentStatus = new ConcurrentHashMap<>(); //TODO: better design possible?
|
|
||||||
|
|
||||||
protected ExecutorService futureListenerExecutor;
|
|
||||||
|
|
||||||
protected List<StatusListener> statusListeners = new ArrayList<>();
|
|
||||||
|
|
||||||
|
|
||||||
protected BaseOptimizationRunner(OptimizationConfiguration config) {
|
|
||||||
this.config = config;
|
|
||||||
|
|
||||||
if (config.getTerminationConditions() == null || config.getTerminationConditions().size() == 0) {
|
|
||||||
throw new IllegalArgumentException("Cannot create BaseOptimizationRunner without TerminationConditions ("
|
|
||||||
+ "termination conditions are null or empty)");
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void init() {
|
|
||||||
futureListenerExecutor = Executors.newFixedThreadPool(maxConcurrentTasks(), new ThreadFactory() {
|
|
||||||
private AtomicLong counter = new AtomicLong(0);
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Thread newThread(Runnable r) {
|
|
||||||
Thread t = Executors.defaultThreadFactory().newThread(r);
|
|
||||||
t.setDaemon(true);
|
|
||||||
t.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement());
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void execute() {
|
|
||||||
log.info("{}: execution started", this.getClass().getSimpleName());
|
|
||||||
config.setExecutionStartTime(System.currentTimeMillis());
|
|
||||||
for (StatusListener listener : statusListeners) {
|
|
||||||
listener.onInitialization(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Initialize termination conditions (start timers, etc)
|
|
||||||
for (TerminationCondition c : config.getTerminationConditions()) {
|
|
||||||
c.initialize(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Queue initial tasks:
|
|
||||||
List<Future<OptimizationResult>> tempList = new ArrayList<>(100);
|
|
||||||
while (true) {
|
|
||||||
//Otherwise: add tasks if required
|
|
||||||
Future<OptimizationResult> future = null;
|
|
||||||
try {
|
|
||||||
future = completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
//No op?
|
|
||||||
}
|
|
||||||
if (future != null) {
|
|
||||||
tempList.add(future);
|
|
||||||
}
|
|
||||||
completedFutures.drainTo(tempList);
|
|
||||||
|
|
||||||
//Process results (if any)
|
|
||||||
for (Future<OptimizationResult> f : tempList) {
|
|
||||||
queuedFutures.remove(f);
|
|
||||||
processReturnedTask(f);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tempList.size() > 0) {
|
|
||||||
for (StatusListener sl : statusListeners) {
|
|
||||||
sl.onRunnerStatusChange(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tempList.clear();
|
|
||||||
|
|
||||||
//Check termination conditions:
|
|
||||||
if (terminate()) {
|
|
||||||
shutdown(true);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
//Add additional tasks
|
|
||||||
while (config.getCandidateGenerator().hasMoreCandidates() && queuedFutures.size() < maxConcurrentTasks()) {
|
|
||||||
Candidate candidate = config.getCandidateGenerator().getCandidate();
|
|
||||||
CandidateInfo status;
|
|
||||||
if (candidate.getException() != null) {
|
|
||||||
//Failed on generation...
|
|
||||||
status = processFailedCandidates(candidate);
|
|
||||||
} else {
|
|
||||||
long created = System.currentTimeMillis();
|
|
||||||
ListenableFuture<OptimizationResult> f;
|
|
||||||
if(config.getDataSource() != null){
|
|
||||||
f = execute(candidate, config.getDataSource(), config.getDataSourceProperties(), config.getScoreFunction());
|
|
||||||
} else {
|
|
||||||
f = execute(candidate, config.getDataProvider(), config.getScoreFunction());
|
|
||||||
}
|
|
||||||
f.addListener(new OnCompletionListener(f), futureListenerExecutor);
|
|
||||||
queuedFutures.add(f);
|
|
||||||
totalCandidateCount.getAndIncrement();
|
|
||||||
|
|
||||||
status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Created, null,
|
|
||||||
created, null, null, candidate.getFlatParameters(), null);
|
|
||||||
currentStatus.put(candidate.getIndex(), status);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (StatusListener listener : statusListeners) {
|
|
||||||
listener.onCandidateStatusChange(status, this, null);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//Process any final (completed) tasks:
|
|
||||||
completedFutures.drainTo(tempList);
|
|
||||||
for (Future<OptimizationResult> f : tempList) {
|
|
||||||
queuedFutures.remove(f);
|
|
||||||
processReturnedTask(f);
|
|
||||||
}
|
|
||||||
tempList.clear();
|
|
||||||
|
|
||||||
log.info("Optimization runner: execution complete");
|
|
||||||
for (StatusListener listener : statusListeners) {
|
|
||||||
listener.onShutdown(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private CandidateInfo processFailedCandidates(Candidate<?> candidate) {
|
|
||||||
//In case the candidate fails during the creation of the candidate
|
|
||||||
|
|
||||||
long time = System.currentTimeMillis();
|
|
||||||
String stackTrace = ExceptionUtils.getStackTrace(candidate.getException());
|
|
||||||
CandidateInfo newStatus = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, time, time,
|
|
||||||
time, candidate.getFlatParameters(), stackTrace);
|
|
||||||
currentStatus.put(candidate.getIndex(), newStatus);
|
|
||||||
|
|
||||||
return newStatus;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Process returned task (either completed or failed
|
|
||||||
*/
|
|
||||||
private void processReturnedTask(Future<OptimizationResult> future) {
|
|
||||||
long currentTime = System.currentTimeMillis();
|
|
||||||
OptimizationResult result;
|
|
||||||
try {
|
|
||||||
result = future.get(100, TimeUnit.MILLISECONDS);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
throw new RuntimeException("Unexpected InterruptedException thrown for task", e);
|
|
||||||
} catch (ExecutionException e) {
|
|
||||||
//Note that most of the time, an OptimizationResult is returned even for an exception
|
|
||||||
//This is just to handle any that are missed there (or, by implementations that don't properly do this)
|
|
||||||
log.warn("Task failed", e);
|
|
||||||
|
|
||||||
numCandidatesFailed.getAndIncrement();
|
|
||||||
return;
|
|
||||||
} catch (TimeoutException e) {
|
|
||||||
throw new RuntimeException(e); //TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
//Update internal status:
|
|
||||||
CandidateInfo status = currentStatus.get(result.getIndex());
|
|
||||||
CandidateInfo newStatus = new CandidateInfo(result.getIndex(), result.getCandidateInfo().getCandidateStatus(),
|
|
||||||
result.getScore(), status.getCreatedTime(), result.getCandidateInfo().getStartTime(),
|
|
||||||
currentTime, status.getFlatParams(), result.getCandidateInfo().getExceptionStackTrace());
|
|
||||||
currentStatus.put(result.getIndex(), newStatus);
|
|
||||||
|
|
||||||
//Listeners (on complete, etc) should be executed in underlying task
|
|
||||||
|
|
||||||
|
|
||||||
if (result.getCandidateInfo().getCandidateStatus() == CandidateStatus.Failed) {
|
|
||||||
log.info("Task {} failed during execution: {}", result.getIndex(), result.getCandidateInfo().getExceptionStackTrace());
|
|
||||||
numCandidatesFailed.getAndIncrement();
|
|
||||||
} else {
|
|
||||||
|
|
||||||
//Report completion to candidate generator
|
|
||||||
config.getCandidateGenerator().reportResults(result);
|
|
||||||
|
|
||||||
Double score = result.getScore();
|
|
||||||
log.info("Completed task {}, score = {}", result.getIndex(), result.getScore());
|
|
||||||
|
|
||||||
boolean minimize = config.getScoreFunction().minimize();
|
|
||||||
if (score != null && (bestScore == null
|
|
||||||
|| ((minimize && score < bestScore) || (!minimize && score > bestScore)))) {
|
|
||||||
if (bestScore == null) {
|
|
||||||
log.info("New best score: {} (first completed model)", score);
|
|
||||||
} else {
|
|
||||||
int idx = result.getIndex();
|
|
||||||
int lastBestIdx = bestScoreCandidateIndex.get();
|
|
||||||
log.info("New best score: {}, model {} (prev={}, model {})", score, idx, bestScore, lastBestIdx);
|
|
||||||
}
|
|
||||||
bestScore = score;
|
|
||||||
bestScoreTime = System.currentTimeMillis();
|
|
||||||
bestScoreCandidateIndex.set(result.getIndex());
|
|
||||||
}
|
|
||||||
numCandidatesCompleted.getAndIncrement();
|
|
||||||
|
|
||||||
//Model saving is done in the optimization tasks, to avoid CUDA threading issues
|
|
||||||
ResultReference resultReference = result.getResultReference();
|
|
||||||
|
|
||||||
if (resultReference != null)
|
|
||||||
allResults.add(resultReference);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numCandidatesTotal() {
|
|
||||||
return totalCandidateCount.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numCandidatesCompleted() {
|
|
||||||
return numCandidatesCompleted.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numCandidatesFailed() {
|
|
||||||
return numCandidatesFailed.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numCandidatesQueued() {
|
|
||||||
return queuedFutures.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Double bestScore() {
|
|
||||||
return bestScore;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Long bestScoreTime() {
|
|
||||||
return bestScoreTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int bestScoreCandidateIndex() {
|
|
||||||
return bestScoreCandidateIndex.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ResultReference> getResults() {
|
|
||||||
return new ArrayList<>(allResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public OptimizationConfiguration getConfiguration() {
|
|
||||||
return config;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addListeners(StatusListener... listeners) {
|
|
||||||
for (StatusListener l : listeners) {
|
|
||||||
if (!statusListeners.contains(l)) {
|
|
||||||
statusListeners.add(l);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeListeners(StatusListener... listeners) {
|
|
||||||
for (StatusListener l : listeners) {
|
|
||||||
statusListeners.remove(l);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void removeAllListeners() {
|
|
||||||
statusListeners.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<CandidateInfo> getCandidateStatus() {
|
|
||||||
return new ArrayList<>(currentStatus.values());
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean terminate() {
|
|
||||||
for (TerminationCondition c : config.getTerminationConditions()) {
|
|
||||||
if (c.terminate(this)) {
|
|
||||||
log.info("BaseOptimizationRunner global termination condition hit: {}", c);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Data
|
|
||||||
private class FutureDetails {
|
|
||||||
private final Future<OptimizationResult> future;
|
|
||||||
private final long startTime;
|
|
||||||
private final int index;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
private class OnCompletionListener implements Runnable {
|
|
||||||
private Future<OptimizationResult> future;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
completedFutures.add(future);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
protected abstract int maxConcurrentTasks();
|
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
|
|
||||||
ScoreFunction scoreFunction);
|
|
||||||
@Deprecated
|
|
||||||
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates,
|
|
||||||
DataProvider dataProvider, ScoreFunction scoreFunction);
|
|
||||||
|
|
||||||
protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource,
|
|
||||||
Properties dataSourceProperties, ScoreFunction scoreFunction);
|
|
||||||
|
|
||||||
protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource,
|
|
||||||
Properties dataSourceProperties, ScoreFunction scoreFunction);
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simple helper class to store status of a candidate that is/has been/will be executed
|
|
||||||
*/
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Data
|
|
||||||
public class CandidateInfo {
|
|
||||||
|
|
||||||
public CandidateInfo() {
|
|
||||||
//No arg constructor for Jackson
|
|
||||||
}
|
|
||||||
|
|
||||||
private int index;
|
|
||||||
private CandidateStatus candidateStatus;
|
|
||||||
private Double score;
|
|
||||||
private long createdTime;
|
|
||||||
private Long startTime;
|
|
||||||
private Long endTime;
|
|
||||||
private double[] flatParams; //Same as parameters in Candidate class
|
|
||||||
private String exceptionStackTrace;
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Status for candidates
|
|
||||||
*/
|
|
||||||
public enum CandidateStatus {
|
|
||||||
Created, Running, Complete, Failed, Cancelled
|
|
||||||
}
|
|
|
@ -1,67 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
|
||||||
public interface IOptimizationRunner {
|
|
||||||
|
|
||||||
void execute();
|
|
||||||
|
|
||||||
/** Total number of candidates: created (scheduled), completed and failed */
|
|
||||||
int numCandidatesTotal();
|
|
||||||
|
|
||||||
int numCandidatesCompleted();
|
|
||||||
|
|
||||||
int numCandidatesFailed();
|
|
||||||
|
|
||||||
/** Number of candidates running or queued */
|
|
||||||
int numCandidatesQueued();
|
|
||||||
|
|
||||||
/** Best score found so far */
|
|
||||||
Double bestScore();
|
|
||||||
|
|
||||||
/** Time that the best score was found at, or 0 if no jobs have completed successfully */
|
|
||||||
Long bestScoreTime();
|
|
||||||
|
|
||||||
/** Index of the best scoring candidate, or -1 if no candidate has scored yet*/
|
|
||||||
int bestScoreCandidateIndex();
|
|
||||||
|
|
||||||
List<ResultReference> getResults();
|
|
||||||
|
|
||||||
OptimizationConfiguration getConfiguration();
|
|
||||||
|
|
||||||
void addListeners(StatusListener... listeners);
|
|
||||||
|
|
||||||
void removeListeners(StatusListener... listeners);
|
|
||||||
|
|
||||||
void removeAllListeners();
|
|
||||||
|
|
||||||
List<CandidateInfo> getCandidateStatus();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param awaitCompletion If true: await completion of currently scheduled tasks. If false: shutdown immediately,
|
|
||||||
* cancelling any currently executing tasks
|
|
||||||
*/
|
|
||||||
void shutdown(boolean awaitCompletion);
|
|
||||||
}
|
|
|
@ -1,150 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.util.concurrent.ListenableFuture;
|
|
||||||
import org.nd4j.shade.guava.util.concurrent.ListeningExecutorService;
|
|
||||||
import org.nd4j.shade.guava.util.concurrent.MoreExecutors;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.*;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Properties;
|
|
||||||
import java.util.concurrent.*;
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* LocalOptimizationRunner: execute hyperparameter optimization
|
|
||||||
* locally (on current machine, in current JVM).
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class LocalOptimizationRunner extends BaseOptimizationRunner {
|
|
||||||
|
|
||||||
public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1;
|
|
||||||
|
|
||||||
private final int maxConcurrentTasks;
|
|
||||||
|
|
||||||
private TaskCreator taskCreator;
|
|
||||||
private ListeningExecutorService executor;
|
|
||||||
@Setter
|
|
||||||
private long shutdownMaxWaitMS = 2L * 24 * 60 * 60 * 1000;
|
|
||||||
|
|
||||||
public LocalOptimizationRunner(OptimizationConfiguration config){
|
|
||||||
this(config, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public LocalOptimizationRunner(OptimizationConfiguration config, TaskCreator taskCreator) {
|
|
||||||
this(DEFAULT_MAX_CONCURRENT_TASKS, config, taskCreator);
|
|
||||||
}
|
|
||||||
|
|
||||||
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config){
|
|
||||||
this(maxConcurrentTasks, config, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config, TaskCreator taskCreator) {
|
|
||||||
super(config);
|
|
||||||
if (maxConcurrentTasks <= 0)
|
|
||||||
throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")");
|
|
||||||
this.maxConcurrentTasks = maxConcurrentTasks;
|
|
||||||
|
|
||||||
if(taskCreator == null){
|
|
||||||
Class<? extends ParameterSpace> psClass = config.getCandidateGenerator().getParameterSpace().getClass();
|
|
||||||
taskCreator = TaskCreatorProvider.defaultTaskCreatorFor(psClass);
|
|
||||||
if(taskCreator == null){
|
|
||||||
throw new IllegalStateException("No TaskCreator was provided and a default TaskCreator cannot be " +
|
|
||||||
"inferred for ParameterSpace class " + psClass.getName() + ". Please provide a TaskCreator " +
|
|
||||||
"via the LocalOptimizationRunner constructor");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.taskCreator = taskCreator;
|
|
||||||
|
|
||||||
ExecutorService exec = Executors.newFixedThreadPool(maxConcurrentTasks, new ThreadFactory() {
|
|
||||||
private AtomicLong counter = new AtomicLong(0);
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Thread newThread(Runnable r) {
|
|
||||||
Thread t = Executors.defaultThreadFactory().newThread(r);
|
|
||||||
t.setDaemon(true);
|
|
||||||
t.setName("LocalCandidateExecutor-" + counter.getAndIncrement());
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
executor = MoreExecutors.listeningDecorator(exec);
|
|
||||||
|
|
||||||
init();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int maxConcurrentTasks() {
|
|
||||||
return maxConcurrentTasks;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider,
|
|
||||||
ScoreFunction scoreFunction) {
|
|
||||||
return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, DataProvider dataProvider,
|
|
||||||
ScoreFunction scoreFunction) {
|
|
||||||
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
|
|
||||||
for (Candidate candidate : candidates) {
|
|
||||||
Callable<OptimizationResult> task =
|
|
||||||
taskCreator.create(candidate, dataProvider, scoreFunction, statusListeners, this);
|
|
||||||
list.add(executor.submit(task));
|
|
||||||
}
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected ListenableFuture<OptimizationResult> execute(Candidate candidate, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
|
|
||||||
return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected List<ListenableFuture<OptimizationResult>> execute(List<Candidate> candidates, Class<? extends DataSource> dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) {
|
|
||||||
List<ListenableFuture<OptimizationResult>> list = new ArrayList<>(candidates.size());
|
|
||||||
for (Candidate candidate : candidates) {
|
|
||||||
Callable<OptimizationResult> task = taskCreator.create(candidate, dataSource, dataSourceProperties, scoreFunction, statusListeners, this);
|
|
||||||
list.add(executor.submit(task));
|
|
||||||
}
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void shutdown(boolean awaitTermination) {
|
|
||||||
if(awaitTermination){
|
|
||||||
try {
|
|
||||||
executor.shutdown();
|
|
||||||
executor.awaitTermination(shutdownMaxWaitMS, TimeUnit.MILLISECONDS);
|
|
||||||
} catch (InterruptedException e){
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
executor.shutdownNow();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,54 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BaseStatusListener: implements all methods of {@link StatusListener} as no-op.
|
|
||||||
* Users can extend this and override only the methods actually required
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public abstract class BaseStatusListener implements StatusListener{
|
|
||||||
@Override
|
|
||||||
public void onInitialization(IOptimizationRunner runner) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onShutdown(IOptimizationRunner runner) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onRunnerStatusChange(IOptimizationRunner runner) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
|
|
||||||
//No op
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 20/07/2017.
|
|
||||||
*/
|
|
||||||
public enum StatusChangeType {
|
|
||||||
|
|
||||||
CandidateCompleted, CandidateFailed, CandidateNewScheduled, CandidateNewBestScore
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner.listener;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The status Listener interface is used to inspect/track the status of execution, both for individual candidates,
|
|
||||||
* and for the optimisation runner overall.
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public interface StatusListener {
|
|
||||||
|
|
||||||
/** Called when optimization runner starts execution */
|
|
||||||
void onInitialization(IOptimizationRunner runner);
|
|
||||||
|
|
||||||
/** Called when optimization runner terminates */
|
|
||||||
void onShutdown(IOptimizationRunner runner);
|
|
||||||
|
|
||||||
/** Called when any of the summary stats change, for the optimization runner:
|
|
||||||
* number scheduled, number completed, number failed, best score, etc. */
|
|
||||||
void onRunnerStatusChange(IOptimizationRunner runner);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Called when the status of the candidate is change. For example created, completed, failed.
|
|
||||||
*
|
|
||||||
* @param candidateInfo Candidate information
|
|
||||||
* @param runner Optimisation runner calling this method
|
|
||||||
* @param result Optimisation result. Maybe null.
|
|
||||||
*/
|
|
||||||
void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This method may be called by tasks as they are executing. The intent of this method is to report partial results,
|
|
||||||
* such as different stages of learning, or scores/evaluations so far
|
|
||||||
*
|
|
||||||
* @param candidateInfo Candidate information
|
|
||||||
* @param candidate Current candidate value/configuration
|
|
||||||
* @param iteration Current iteration number
|
|
||||||
*/
|
|
||||||
void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,57 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.runner.listener.impl;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 20/07/2017.
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class LoggingStatusListener implements StatusListener {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onInitialization(IOptimizationRunner runner) {
|
|
||||||
log.info("Optimization runner: initialized");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onShutdown(IOptimizationRunner runner) {
|
|
||||||
log.info("Optimization runner: shut down");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onRunnerStatusChange(IOptimizationRunner runner) {
|
|
||||||
log.info("Optimization runner: status change");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner,
|
|
||||||
OptimizationResult result) {
|
|
||||||
log.info("Candidate status change: {}", candidateInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) {
|
|
||||||
log.info("Candidate iteration #{} - {}", iteration, candidate);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.codec.binary.Base64;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonParser;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|
||||||
|
|
||||||
import java.io.ByteArrayInputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.ObjectInputStream;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A custom deserializer to be used in conjunction with {@link FixedValueSerializer}
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class FixedValueDeserializer extends JsonDeserializer<FixedValue> {
|
|
||||||
@Override
|
|
||||||
public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
|
|
||||||
JsonNode node = p.getCodec().readTree(p);
|
|
||||||
String className = node.get("@valueclass").asText();
|
|
||||||
Class<?> c;
|
|
||||||
try {
|
|
||||||
c = Class.forName(className);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(node.has("value")){
|
|
||||||
//Number, String, Enum
|
|
||||||
JsonNode valueNode = node.get("value");
|
|
||||||
Object o = new ObjectMapper().treeToValue(valueNode, c);
|
|
||||||
return new FixedValue<>(o);
|
|
||||||
} else {
|
|
||||||
//Everything else
|
|
||||||
JsonNode valueNode = node.get("data");
|
|
||||||
String data = valueNode.asText();
|
|
||||||
|
|
||||||
byte[] b = new Base64().decode(data);
|
|
||||||
ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b));
|
|
||||||
try {
|
|
||||||
Object o = ois.readObject();
|
|
||||||
return new FixedValue<>(o);
|
|
||||||
} catch (Throwable t) {
|
|
||||||
throw new RuntimeException(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,51 +0,0 @@
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.net.util.Base64;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
|
||||||
import org.nd4j.shade.jackson.core.type.WritableTypeId;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
|
||||||
import org.nd4j.shade.jackson.databind.jsontype.TypeSerializer;
|
|
||||||
|
|
||||||
import java.io.ByteArrayOutputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.ObjectOutputStream;
|
|
||||||
|
|
||||||
import static org.nd4j.shade.jackson.core.JsonToken.START_OBJECT;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A custom serializer to handle arbitrary object types
|
|
||||||
* Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64)
|
|
||||||
* The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary
|
|
||||||
* objects very well
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class FixedValueSerializer extends JsonSerializer<FixedValue> {
|
|
||||||
@Override
|
|
||||||
public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException {
|
|
||||||
Object o = fixedValue.getValue();
|
|
||||||
|
|
||||||
j.writeStringField("@valueclass", o.getClass().getName());
|
|
||||||
if(o instanceof Number || o instanceof String || o instanceof Enum){
|
|
||||||
j.writeObjectField("value", o);
|
|
||||||
} else {
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ObjectOutputStream oos = new ObjectOutputStream(baos);
|
|
||||||
oos.writeObject(o);
|
|
||||||
baos.close();
|
|
||||||
byte[] b = baos.toByteArray();
|
|
||||||
String base64 = new Base64().encodeToString(b);
|
|
||||||
j.writeStringField("data", base64);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException {
|
|
||||||
WritableTypeId typeId = typeSer.typeId(value, START_OBJECT);
|
|
||||||
typeSer.writeTypePrefix(gen, typeId);
|
|
||||||
serialize(value, gen, serializers);
|
|
||||||
typeSer.writeTypeSuffix(gen, typeId);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,59 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonParser;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Custom Jackson deserializer for integer distributions
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class IntegerDistributionDeserializer extends JsonDeserializer<IntegerDistribution> {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public IntegerDistribution deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
|
|
||||||
JsonNode node = p.getCodec().readTree(p);
|
|
||||||
String simpleName = node.get("distribution").asText();
|
|
||||||
|
|
||||||
switch (simpleName) {
|
|
||||||
case "BinomialDistribution":
|
|
||||||
return new BinomialDistribution(node.get("trials").asInt(), node.get("p").asDouble());
|
|
||||||
case "GeometricDistribution":
|
|
||||||
return new GeometricDistribution(node.get("p").asDouble());
|
|
||||||
case "HypergeometricDistribution":
|
|
||||||
return new HypergeometricDistribution(node.get("populationSize").asInt(),
|
|
||||||
node.get("numberOfSuccesses").asInt(), node.get("sampleSize").asInt());
|
|
||||||
case "PascalDistribution":
|
|
||||||
return new PascalDistribution(node.get("r").asInt(), node.get("p").asDouble());
|
|
||||||
case "PoissonDistribution":
|
|
||||||
return new PoissonDistribution(node.get("p").asDouble());
|
|
||||||
case "UniformIntegerDistribution":
|
|
||||||
return new UniformIntegerDistribution(node.get("lower").asInt(), node.get("upper").asInt());
|
|
||||||
case "ZipfDistribution":
|
|
||||||
return new ZipfDistribution(node.get("numElements").asInt(), node.get("exponent").asDouble());
|
|
||||||
default:
|
|
||||||
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Custom Jackson serializer for integer distributions
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class IntegerDistributionSerializer extends JsonSerializer<IntegerDistribution> {
|
|
||||||
@Override
|
|
||||||
public void serialize(IntegerDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
|
|
||||||
throws IOException {
|
|
||||||
Class<?> c = d.getClass();
|
|
||||||
String s = c.getSimpleName();
|
|
||||||
|
|
||||||
j.writeStartObject();
|
|
||||||
j.writeStringField("distribution", s);
|
|
||||||
|
|
||||||
if (c == BinomialDistribution.class) {
|
|
||||||
BinomialDistribution bd = (BinomialDistribution) d;
|
|
||||||
j.writeNumberField("trials", bd.getNumberOfTrials());
|
|
||||||
j.writeNumberField("p", bd.getProbabilityOfSuccess());
|
|
||||||
} else if (c == GeometricDistribution.class) {
|
|
||||||
GeometricDistribution gd = (GeometricDistribution) d;
|
|
||||||
j.writeNumberField("p", gd.getProbabilityOfSuccess());
|
|
||||||
} else if (c == HypergeometricDistribution.class) {
|
|
||||||
HypergeometricDistribution hd = (HypergeometricDistribution) d;
|
|
||||||
j.writeNumberField("populationSize", hd.getPopulationSize());
|
|
||||||
j.writeNumberField("numberOfSuccesses", hd.getNumberOfSuccesses());
|
|
||||||
j.writeNumberField("sampleSize", hd.getSampleSize());
|
|
||||||
} else if (c == PascalDistribution.class) {
|
|
||||||
PascalDistribution pd = (PascalDistribution) d;
|
|
||||||
j.writeNumberField("r", pd.getNumberOfSuccesses());
|
|
||||||
j.writeNumberField("p", pd.getProbabilityOfSuccess());
|
|
||||||
} else if (c == PoissonDistribution.class) {
|
|
||||||
PoissonDistribution pd = (PoissonDistribution) d;
|
|
||||||
j.writeNumberField("p", pd.getMean());
|
|
||||||
} else if (c == UniformIntegerDistribution.class) {
|
|
||||||
UniformIntegerDistribution ud = (UniformIntegerDistribution) d;
|
|
||||||
j.writeNumberField("lower", ud.getSupportLowerBound());
|
|
||||||
j.writeNumberField("upper", ud.getSupportUpperBound());
|
|
||||||
} else if (c == ZipfDistribution.class) {
|
|
||||||
ZipfDistribution zd = (ZipfDistribution) d;
|
|
||||||
j.writeNumberField("numElements", zd.getNumberOfElements());
|
|
||||||
j.writeNumberField("exponent", zd.getExponent());
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c);
|
|
||||||
}
|
|
||||||
|
|
||||||
j.writeEndObject();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
|
||||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 16/11/2016.
|
|
||||||
*/
|
|
||||||
public class JsonMapper {
|
|
||||||
|
|
||||||
private static ObjectMapper mapper;
|
|
||||||
private static ObjectMapper yamlMapper;
|
|
||||||
|
|
||||||
static {
|
|
||||||
mapper = new ObjectMapper();
|
|
||||||
mapper.registerModule(new JodaModule());
|
|
||||||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
|
||||||
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
|
||||||
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
|
||||||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
|
||||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
|
||||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
|
||||||
mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY);
|
|
||||||
yamlMapper = new ObjectMapper(new YAMLFactory());
|
|
||||||
yamlMapper.registerModule(new JodaModule());
|
|
||||||
yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
|
||||||
yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
|
||||||
yamlMapper.enable(SerializationFeature.INDENT_OUTPUT);
|
|
||||||
yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
|
||||||
yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
|
||||||
yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
|
||||||
}
|
|
||||||
|
|
||||||
private JsonMapper() {}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the yaml mapper
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper getYamlMapper() {
|
|
||||||
return yamlMapper;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return a json mapper
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper getMapper() {
|
|
||||||
return mapper;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,78 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonParser;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationContext;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonDeserializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonNode;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 14/02/2017.
|
|
||||||
*/
|
|
||||||
public class RealDistributionDeserializer extends JsonDeserializer<RealDistribution> {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RealDistribution deserialize(JsonParser p, DeserializationContext ctxt)
|
|
||||||
throws IOException, JsonProcessingException {
|
|
||||||
JsonNode node = p.getCodec().readTree(p);
|
|
||||||
String simpleName = node.get("distribution").asText();
|
|
||||||
|
|
||||||
switch (simpleName) {
|
|
||||||
case "BetaDistribution":
|
|
||||||
return new BetaDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
|
|
||||||
case "CauchyDistribution":
|
|
||||||
return new CauchyDistribution(node.get("median").asDouble(), node.get("scale").asDouble());
|
|
||||||
case "ChiSquaredDistribution":
|
|
||||||
return new ChiSquaredDistribution(node.get("dof").asDouble());
|
|
||||||
case "ExponentialDistribution":
|
|
||||||
return new ExponentialDistribution(node.get("mean").asDouble());
|
|
||||||
case "FDistribution":
|
|
||||||
return new FDistribution(node.get("numeratorDof").asDouble(), node.get("denominatorDof").asDouble());
|
|
||||||
case "GammaDistribution":
|
|
||||||
return new GammaDistribution(node.get("shape").asDouble(), node.get("scale").asDouble());
|
|
||||||
case "LevyDistribution":
|
|
||||||
return new LevyDistribution(node.get("mu").asDouble(), node.get("c").asDouble());
|
|
||||||
case "LogNormalDistribution":
|
|
||||||
return new LogNormalDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
|
|
||||||
case "NormalDistribution":
|
|
||||||
return new NormalDistribution(node.get("mean").asDouble(), node.get("stdev").asDouble());
|
|
||||||
case "ParetoDistribution":
|
|
||||||
return new ParetoDistribution(node.get("scale").asDouble(), node.get("shape").asDouble());
|
|
||||||
case "TDistribution":
|
|
||||||
return new TDistribution(node.get("dof").asDouble());
|
|
||||||
case "TriangularDistribution":
|
|
||||||
return new TriangularDistribution(node.get("a").asDouble(), node.get("b").asDouble(),
|
|
||||||
node.get("c").asDouble());
|
|
||||||
case "UniformRealDistribution":
|
|
||||||
return new UniformRealDistribution(node.get("lower").asDouble(), node.get("upper").asDouble());
|
|
||||||
case "WeibullDistribution":
|
|
||||||
return new WeibullDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble());
|
|
||||||
case "LogUniformDistribution":
|
|
||||||
return new LogUniformDistribution(node.get("min").asDouble(), node.get("max").asDouble());
|
|
||||||
default:
|
|
||||||
throw new RuntimeException("Unknown or not supported distribution: " + simpleName);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,107 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.*;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonGenerator;
|
|
||||||
import org.nd4j.shade.jackson.databind.JsonSerializer;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializerProvider;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Custom JSON serializer for Apache commons RealDistribution instances.
|
|
||||||
* The custom serializer is set up to use the built-in c
|
|
||||||
*/
|
|
||||||
public class RealDistributionSerializer extends JsonSerializer<RealDistribution> {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void serialize(RealDistribution d, JsonGenerator j, SerializerProvider serializerProvider)
|
|
||||||
throws IOException {
|
|
||||||
Class<?> c = d.getClass();
|
|
||||||
String s = c.getSimpleName();
|
|
||||||
|
|
||||||
j.writeStartObject();
|
|
||||||
j.writeStringField("distribution", s);
|
|
||||||
|
|
||||||
|
|
||||||
if (c == BetaDistribution.class) {
|
|
||||||
BetaDistribution bd = (BetaDistribution) d;
|
|
||||||
j.writeNumberField("alpha", bd.getAlpha());
|
|
||||||
j.writeNumberField("beta", bd.getBeta());
|
|
||||||
} else if (c == CauchyDistribution.class) {
|
|
||||||
CauchyDistribution cd = (CauchyDistribution) d;
|
|
||||||
j.writeNumberField("median", cd.getMedian());
|
|
||||||
j.writeNumberField("scale", cd.getScale());
|
|
||||||
} else if (c == ChiSquaredDistribution.class) {
|
|
||||||
ChiSquaredDistribution cd = (ChiSquaredDistribution) d;
|
|
||||||
j.writeNumberField("dof", cd.getDegreesOfFreedom());
|
|
||||||
} else if (c == ExponentialDistribution.class) {
|
|
||||||
ExponentialDistribution ed = (ExponentialDistribution) d;
|
|
||||||
j.writeNumberField("mean", ed.getMean());
|
|
||||||
} else if (c == FDistribution.class) {
|
|
||||||
FDistribution fd = (FDistribution) d;
|
|
||||||
j.writeNumberField("numeratorDof", fd.getNumeratorDegreesOfFreedom());
|
|
||||||
j.writeNumberField("denominatorDof", fd.getDenominatorDegreesOfFreedom());
|
|
||||||
} else if (c == GammaDistribution.class) {
|
|
||||||
GammaDistribution gd = (GammaDistribution) d;
|
|
||||||
j.writeNumberField("shape", gd.getShape());
|
|
||||||
j.writeNumberField("scale", gd.getScale());
|
|
||||||
} else if (c == LevyDistribution.class) {
|
|
||||||
LevyDistribution ld = (LevyDistribution) d;
|
|
||||||
j.writeNumberField("mu", ld.getLocation());
|
|
||||||
j.writeNumberField("c", ld.getScale());
|
|
||||||
} else if (c == LogNormalDistribution.class) {
|
|
||||||
LogNormalDistribution ln = (LogNormalDistribution) d;
|
|
||||||
j.writeNumberField("scale", ln.getScale());
|
|
||||||
j.writeNumberField("shape", ln.getShape());
|
|
||||||
} else if (c == NormalDistribution.class) {
|
|
||||||
NormalDistribution nd = (NormalDistribution) d;
|
|
||||||
j.writeNumberField("mean", nd.getMean());
|
|
||||||
j.writeNumberField("stdev", nd.getStandardDeviation());
|
|
||||||
} else if (c == ParetoDistribution.class) {
|
|
||||||
ParetoDistribution pd = (ParetoDistribution) d;
|
|
||||||
j.writeNumberField("scale", pd.getScale());
|
|
||||||
j.writeNumberField("shape", pd.getShape());
|
|
||||||
} else if (c == TDistribution.class) {
|
|
||||||
TDistribution td = (TDistribution) d;
|
|
||||||
j.writeNumberField("dof", td.getDegreesOfFreedom());
|
|
||||||
} else if (c == TriangularDistribution.class) {
|
|
||||||
TriangularDistribution td = (TriangularDistribution) d;
|
|
||||||
j.writeNumberField("a", td.getSupportLowerBound());
|
|
||||||
j.writeNumberField("b", td.getMode());
|
|
||||||
j.writeNumberField("c", td.getSupportUpperBound());
|
|
||||||
} else if (c == UniformRealDistribution.class) {
|
|
||||||
UniformRealDistribution u = (UniformRealDistribution) d;
|
|
||||||
j.writeNumberField("lower", u.getSupportLowerBound());
|
|
||||||
j.writeNumberField("upper", u.getSupportUpperBound());
|
|
||||||
} else if (c == WeibullDistribution.class) {
|
|
||||||
WeibullDistribution wb = (WeibullDistribution) d;
|
|
||||||
j.writeNumberField("alpha", wb.getShape());
|
|
||||||
j.writeNumberField("beta", wb.getScale());
|
|
||||||
} else if (c == LogUniformDistribution.class){
|
|
||||||
LogUniformDistribution lud = (LogUniformDistribution) d;
|
|
||||||
j.writeNumberField("min", lud.getMin());
|
|
||||||
j.writeNumberField("max", lud.getMax());
|
|
||||||
} else {
|
|
||||||
throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + d.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
j.writeEndObject();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,52 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize.serde.jackson;
|
|
||||||
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
|
||||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 16/11/2016.
|
|
||||||
*/
|
|
||||||
public class YamlMapper {
|
|
||||||
|
|
||||||
private static final ObjectMapper mapper;
|
|
||||||
|
|
||||||
static {
|
|
||||||
mapper = new ObjectMapper(new YAMLFactory());
|
|
||||||
mapper.registerModule(new JodaModule());
|
|
||||||
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
|
||||||
mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
|
||||||
mapper.enable(SerializationFeature.INDENT_OUTPUT);
|
|
||||||
mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
|
||||||
mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
|
||||||
mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private YamlMapper() {}
|
|
||||||
|
|
||||||
public static ObjectMapper getMapper() {
|
|
||||||
return mapper;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,233 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.util;
|
|
||||||
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.net.MalformedURLException;
|
|
||||||
import java.net.URI;
|
|
||||||
import java.net.URISyntaxException;
|
|
||||||
import java.net.URL;
|
|
||||||
import java.util.zip.ZipEntry;
|
|
||||||
import java.util.zip.ZipFile;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simple utility class used to get access to files at the classpath, or packed into jar.
|
|
||||||
* Based on Spring ClassPathResource implementation + jar internals access implemented.
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* @author raver119@gmail.com
|
|
||||||
*/
|
|
||||||
public class ClassPathResource {
|
|
||||||
|
|
||||||
private String resourceName;
|
|
||||||
|
|
||||||
private static Logger log = LoggerFactory.getLogger(ClassPathResource.class);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Builds new ClassPathResource object
|
|
||||||
*
|
|
||||||
* @param resourceName String name of resource, to be retrieved
|
|
||||||
*/
|
|
||||||
public ClassPathResource(String resourceName) {
|
|
||||||
if (resourceName == null)
|
|
||||||
throw new IllegalStateException("Resource name can't be null");
|
|
||||||
this.resourceName = resourceName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns URL of the requested resource
|
|
||||||
*
|
|
||||||
* @return URL of the resource, if it's available in current Jar
|
|
||||||
*/
|
|
||||||
private URL getUrl() {
|
|
||||||
ClassLoader loader = null;
|
|
||||||
try {
|
|
||||||
loader = Thread.currentThread().getContextClassLoader();
|
|
||||||
} catch (Exception e) {
|
|
||||||
// do nothing
|
|
||||||
}
|
|
||||||
|
|
||||||
if (loader == null) {
|
|
||||||
loader = ClassPathResource.class.getClassLoader();
|
|
||||||
}
|
|
||||||
|
|
||||||
URL url = loader.getResource(this.resourceName);
|
|
||||||
if (url == null) {
|
|
||||||
// try to check for mis-used starting slash
|
|
||||||
// TODO: see TODO below
|
|
||||||
if (this.resourceName.startsWith("/")) {
|
|
||||||
url = loader.getResource(this.resourceName.replaceFirst("[\\\\/]", ""));
|
|
||||||
if (url != null)
|
|
||||||
return url;
|
|
||||||
} else {
|
|
||||||
// try to add slash, to make clear it's not an issue
|
|
||||||
// TODO: change this mechanic to actual path purifier
|
|
||||||
url = loader.getResource("/" + this.resourceName);
|
|
||||||
if (url != null)
|
|
||||||
return url;
|
|
||||||
}
|
|
||||||
throw new IllegalStateException("Resource '" + this.resourceName + "' cannot be found.");
|
|
||||||
}
|
|
||||||
return url;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns requested ClassPathResource as File object
|
|
||||||
*
|
|
||||||
* Please note: if this method called from compiled jar, temporary file will be created to provide File access
|
|
||||||
*
|
|
||||||
* @return File requested at constructor call
|
|
||||||
* @throws FileNotFoundException
|
|
||||||
*/
|
|
||||||
public File getFile() throws FileNotFoundException {
|
|
||||||
URL url = this.getUrl();
|
|
||||||
|
|
||||||
if (isJarURL(url)) {
|
|
||||||
/*
|
|
||||||
This is actually request for file, that's packed into jar. Probably the current one, but that doesn't matters.
|
|
||||||
*/
|
|
||||||
try {
|
|
||||||
url = extractActualUrl(url);
|
|
||||||
File file = File.createTempFile("canova_temp", "file");
|
|
||||||
file.deleteOnExit();
|
|
||||||
|
|
||||||
ZipFile zipFile = new ZipFile(url.getFile());
|
|
||||||
ZipEntry entry = zipFile.getEntry(this.resourceName);
|
|
||||||
if (entry == null) {
|
|
||||||
if (this.resourceName.startsWith("/")) {
|
|
||||||
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
|
|
||||||
if (entry == null) {
|
|
||||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
|
||||||
}
|
|
||||||
|
|
||||||
long size = entry.getSize();
|
|
||||||
|
|
||||||
InputStream stream = zipFile.getInputStream(entry);
|
|
||||||
FileOutputStream outputStream = new FileOutputStream(file);
|
|
||||||
byte[] array = new byte[1024];
|
|
||||||
int rd = 0;
|
|
||||||
long bytesRead = 0;
|
|
||||||
do {
|
|
||||||
rd = stream.read(array);
|
|
||||||
outputStream.write(array, 0, rd);
|
|
||||||
bytesRead += rd;
|
|
||||||
} while (bytesRead < size);
|
|
||||||
|
|
||||||
outputStream.flush();
|
|
||||||
outputStream.close();
|
|
||||||
|
|
||||||
stream.close();
|
|
||||||
zipFile.close();
|
|
||||||
|
|
||||||
return file;
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
/*
|
|
||||||
It's something in the actual underlying filesystem, so we can just go for it
|
|
||||||
*/
|
|
||||||
|
|
||||||
try {
|
|
||||||
URI uri = new URI(url.toString().replaceAll(" ", "%20"));
|
|
||||||
return new File(uri.getSchemeSpecificPart());
|
|
||||||
} catch (URISyntaxException e) {
|
|
||||||
return new File(url.getFile());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks, if proposed URL is packed into archive.
|
|
||||||
*
|
|
||||||
* @param url URL to be checked
|
|
||||||
* @return True, if URL is archive entry, False otherwise
|
|
||||||
*/
|
|
||||||
private boolean isJarURL(URL url) {
|
|
||||||
String protocol = url.getProtocol();
|
|
||||||
return "jar".equals(protocol) || "zip".equals(protocol) || "wsjar".equals(protocol)
|
|
||||||
|| "code-source".equals(protocol) && url.getPath().contains("!/");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extracts parent Jar URL from original ClassPath entry URL.
|
|
||||||
*
|
|
||||||
* @param jarUrl Original URL of the resource
|
|
||||||
* @return URL of the Jar file, containing requested resource
|
|
||||||
* @throws MalformedURLException
|
|
||||||
*/
|
|
||||||
private URL extractActualUrl(URL jarUrl) throws MalformedURLException {
|
|
||||||
String urlFile = jarUrl.getFile();
|
|
||||||
int separatorIndex = urlFile.indexOf("!/");
|
|
||||||
if (separatorIndex != -1) {
|
|
||||||
String jarFile = urlFile.substring(0, separatorIndex);
|
|
||||||
|
|
||||||
try {
|
|
||||||
return new URL(jarFile);
|
|
||||||
} catch (MalformedURLException var5) {
|
|
||||||
if (!jarFile.startsWith("/")) {
|
|
||||||
jarFile = "/" + jarFile;
|
|
||||||
}
|
|
||||||
|
|
||||||
return new URL("file:" + jarFile);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return jarUrl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns requested ClassPathResource as InputStream object
|
|
||||||
*
|
|
||||||
* @return File requested at constructor call
|
|
||||||
* @throws FileNotFoundException
|
|
||||||
*/
|
|
||||||
public InputStream getInputStream() throws FileNotFoundException {
|
|
||||||
URL url = this.getUrl();
|
|
||||||
if (isJarURL(url)) {
|
|
||||||
try {
|
|
||||||
url = extractActualUrl(url);
|
|
||||||
ZipFile zipFile = new ZipFile(url.getFile());
|
|
||||||
ZipEntry entry = zipFile.getEntry(this.resourceName);
|
|
||||||
|
|
||||||
if (entry == null) {
|
|
||||||
if (this.resourceName.startsWith("/")) {
|
|
||||||
entry = zipFile.getEntry(this.resourceName.replaceFirst("/", ""));
|
|
||||||
if (entry == null) {
|
|
||||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
throw new FileNotFoundException("Resource " + this.resourceName + " not found");
|
|
||||||
}
|
|
||||||
|
|
||||||
return zipFile.getInputStream(entry);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
File srcFile = this.getFile();
|
|
||||||
return new FileInputStream(srcFile);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,49 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.util;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class CollectionUtils {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Count the number of unique values in a collection
|
|
||||||
*/
|
|
||||||
public static int countUnique(Collection<?> collection) {
|
|
||||||
HashSet<Object> set = new HashSet<>(collection);
|
|
||||||
return set.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a list containing only unique values in a collection
|
|
||||||
*/
|
|
||||||
public static <T> List<T> getUnique(Collection<T> collection) {
|
|
||||||
HashSet<T> set = new HashSet<>();
|
|
||||||
List<T> out = new ArrayList<>();
|
|
||||||
for (T t : collection) {
|
|
||||||
if (!set.contains(t)) {
|
|
||||||
out.add(t);
|
|
||||||
set.add(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.util;
|
|
||||||
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 29/06/2017.
|
|
||||||
*/
|
|
||||||
public class LeafUtils {
|
|
||||||
|
|
||||||
private LeafUtils() {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a list of unique objects, not using the .equals() method, but rather using ==
|
|
||||||
*
|
|
||||||
* @param allLeaves Leaf values to process
|
|
||||||
* @return A list of unique parameter space values
|
|
||||||
*/
|
|
||||||
public static List<ParameterSpace> getUniqueObjects(List<ParameterSpace> allLeaves) {
|
|
||||||
List<ParameterSpace> unique = new ArrayList<>();
|
|
||||||
for (ParameterSpace p : allLeaves) {
|
|
||||||
//This isn't especially efficient, but small number of parameters in general means it's fine
|
|
||||||
boolean found = false;
|
|
||||||
for (ParameterSpace q : unique) {
|
|
||||||
if (p == q) {
|
|
||||||
found = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!found) {
|
|
||||||
unique.add(p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return unique;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Count the number of unique parameters in the specified leaf nodes
|
|
||||||
*
|
|
||||||
* @param allLeaves Leaf values to count the parameters fore
|
|
||||||
* @return Number of parameters for all unique objects
|
|
||||||
*/
|
|
||||||
public static int countUniqueParameters(List<ParameterSpace> allLeaves) {
|
|
||||||
List<ParameterSpace> unique = getUniqueObjects(allLeaves);
|
|
||||||
int count = 0;
|
|
||||||
for (ParameterSpace ps : unique) {
|
|
||||||
if (!ps.isLeaf()) {
|
|
||||||
throw new IllegalStateException("Method should only be used with leaf nodes");
|
|
||||||
}
|
|
||||||
count += ps.numParameters();
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,61 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.util;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
public class ObjectUtils {
|
|
||||||
|
|
||||||
private ObjectUtils() {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the string representation of the object. Arrays, including primitive arrays, are printed using
|
|
||||||
* Arrays.toString(...) methods.
|
|
||||||
*
|
|
||||||
* @param v Value to convert to a string
|
|
||||||
* @return String representation
|
|
||||||
*/
|
|
||||||
public static String valueToString(Object v) {
|
|
||||||
if (v.getClass().isArray()) {
|
|
||||||
if (v.getClass().getComponentType().isPrimitive()) {
|
|
||||||
Class<?> c = v.getClass().getComponentType();
|
|
||||||
if (c == int.class) {
|
|
||||||
return Arrays.toString((int[]) v);
|
|
||||||
} else if (c == double.class) {
|
|
||||||
return Arrays.toString((double[]) v);
|
|
||||||
} else if (c == float.class) {
|
|
||||||
return Arrays.toString((float[]) v);
|
|
||||||
} else if (c == long.class) {
|
|
||||||
return Arrays.toString((long[]) v);
|
|
||||||
} else if (c == byte.class) {
|
|
||||||
return Arrays.toString((byte[]) v);
|
|
||||||
} else if (c == short.class) {
|
|
||||||
return Arrays.toString((short[]) v);
|
|
||||||
} else {
|
|
||||||
return v.toString();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Arrays.toString((Object[]) v);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return v.toString();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,49 +0,0 @@
|
||||||
/* ******************************************************************************
|
|
||||||
* Copyright (c) 2020 Konduit K.K.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.nd4j.common.tests.AbstractAssertTestsClass;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This class checks that all test classes (i.e., anything with one or more methods annotated with @Test)
|
|
||||||
* extends BaseDl4jTest - either directly or indirectly.
|
|
||||||
* Other than a small set of exceptions, all tests must extend this
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Set<Class<?>> getExclusions() {
|
|
||||||
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
|
|
||||||
return new HashSet<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected String getPackageName() {
|
|
||||||
return "org.deeplearning4j.arbiter.optimize";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Class<?> getBaseClass() {
|
|
||||||
return BaseDL4JTest.class;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,156 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.*;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.concurrent.Callable;
|
|
||||||
|
|
||||||
public class BraninFunction {
|
|
||||||
public static class BraninSpace extends AbstractParameterSpace<BraninConfig> {
|
|
||||||
private int[] indices;
|
|
||||||
private ParameterSpace<Double> first = new ContinuousParameterSpace(-5, 10);
|
|
||||||
private ParameterSpace<Double> second = new ContinuousParameterSpace(0, 15);
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public BraninConfig getValue(double[] parameterValues) {
|
|
||||||
double f = first.getValue(parameterValues);
|
|
||||||
double s = second.getValue(parameterValues);
|
|
||||||
return new BraninConfig(f, s); //-5 to +10 and 0 to 15
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numParameters() {
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParameterSpace> collectLeaves() {
|
|
||||||
List<ParameterSpace> list = new ArrayList<>();
|
|
||||||
list.addAll(first.collectLeaves());
|
|
||||||
list.addAll(second.collectLeaves());
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setIndices(int... indices) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Data
|
|
||||||
public static class BraninConfig implements Serializable {
|
|
||||||
private double x1;
|
|
||||||
private double x2;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class BraninScoreFunction implements ScoreFunction {
|
|
||||||
private static final double a = 1.0;
|
|
||||||
private static final double b = 5.1 / (4.0 * Math.PI * Math.PI);
|
|
||||||
private static final double c = 5.0 / Math.PI;
|
|
||||||
private static final double r = 6.0;
|
|
||||||
private static final double s = 10.0;
|
|
||||||
private static final double t = 1.0 / (8.0 * Math.PI);
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double score(Object m, DataProvider data, Map<String, Object> dataParameters) {
|
|
||||||
BraninConfig model = (BraninConfig) m;
|
|
||||||
double x1 = model.getX1();
|
|
||||||
double x2 = model.getX2();
|
|
||||||
|
|
||||||
return a * Math.pow(x2 - b * x1 * x1 + c * x1 - r, 2.0) + s * (1 - t) * Math.cos(x1) + s;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double score(Object model, Class<? extends DataSource> dataSource, Properties dataSourceProperties) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean minimize() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Class<?>> getSupportedModelTypes() {
|
|
||||||
return Collections.<Class<?>>singletonList(BraninConfig.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Class<?>> getSupportedDataTypes() {
|
|
||||||
return Collections.<Class<?>>singletonList(Object.class);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class BraninTaskCreator implements TaskCreator {
|
|
||||||
@Override
|
|
||||||
public Callable<OptimizationResult> create(final Candidate c, DataProvider dataProvider,
|
|
||||||
final ScoreFunction scoreFunction, final List<StatusListener> statusListeners,
|
|
||||||
IOptimizationRunner runner) {
|
|
||||||
|
|
||||||
return new Callable<OptimizationResult>() {
|
|
||||||
@Override
|
|
||||||
public OptimizationResult call() throws Exception {
|
|
||||||
|
|
||||||
BraninConfig candidate = (BraninConfig) c.getValue();
|
|
||||||
|
|
||||||
double score = scoreFunction.score(candidate, null, (Map) null);
|
|
||||||
// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score);
|
|
||||||
|
|
||||||
Thread.sleep(20);
|
|
||||||
|
|
||||||
if (statusListeners != null) {
|
|
||||||
for (StatusListener sl : statusListeners) {
|
|
||||||
sl.onCandidateIteration(null, null, 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CandidateInfo ci = new CandidateInfo(-1, CandidateStatus.Complete, score,
|
|
||||||
System.currentTimeMillis(), null, null, null, null);
|
|
||||||
|
|
||||||
return new OptimizationResult(c, score, c.getIndex(), null, ci, null);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Callable<OptimizationResult> create(Candidate candidate, Class<? extends DataSource> dataSource,
|
|
||||||
Properties dataSourceProperties, ScoreFunction scoreFunction,
|
|
||||||
List<StatusListener> statusListeners, IOptimizationRunner runner) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,118 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
public class TestGeneticSearch extends BaseDL4JTest {
|
|
||||||
public class TestSelectionOperator extends SelectionOperator {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double[] buildNextGenes() {
|
|
||||||
throw new GeneticGenerationException("Forced exception to test exception handling.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public class TestTerminationCondition implements TerminationCondition {
|
|
||||||
|
|
||||||
public boolean hasAFailedCandidate = false;
|
|
||||||
public int evalCount = 0;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void initialize(IOptimizationRunner optimizationRunner) {}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean terminate(IOptimizationRunner optimizationRunner) {
|
|
||||||
if (++evalCount == 50) {
|
|
||||||
// Generator did not handle GeneticGenerationException
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (CandidateInfo candidateInfo : optimizationRunner.getCandidateStatus()) {
|
|
||||||
if (candidateInfo.getCandidateStatus() == CandidateStatus.Failed) {
|
|
||||||
hasAFailedCandidate = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void GeneticSearchCandidateGenerator_getCandidate_ShouldGenerateCandidates() throws Exception {
|
|
||||||
|
|
||||||
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
|
|
||||||
|
|
||||||
//Define configuration:
|
|
||||||
CandidateGenerator candidateGenerator =
|
|
||||||
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
||||||
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
|
|
||||||
.terminationConditions(new MaxCandidatesCondition(50), testTerminationCondition).build();
|
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
|
||||||
|
|
||||||
runner.addListeners(new LoggingStatusListener());
|
|
||||||
runner.execute();
|
|
||||||
|
|
||||||
Assert.assertFalse(testTerminationCondition.hasAFailedCandidate);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void GeneticSearchCandidateGenerator_getCandidate_GeneticExceptionShouldMarkCandidateAsFailed() {
|
|
||||||
|
|
||||||
ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction();
|
|
||||||
|
|
||||||
//Define configuration:
|
|
||||||
CandidateGenerator candidateGenerator =
|
|
||||||
new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction)
|
|
||||||
.selectionOperator(new TestSelectionOperator()).build();
|
|
||||||
|
|
||||||
TestTerminationCondition testTerminationCondition = new TestTerminationCondition();
|
|
||||||
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
||||||
.candidateGenerator(candidateGenerator).scoreFunction(scoreFunction)
|
|
||||||
.terminationConditions(testTerminationCondition).build();
|
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
|
||||||
|
|
||||||
runner.addListeners(new LoggingStatusListener());
|
|
||||||
runner.execute();
|
|
||||||
|
|
||||||
Assert.assertTrue(testTerminationCondition.hasAFailedCandidate);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,104 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
public class TestGridSearch extends BaseDL4JTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testIndexing() {
|
|
||||||
int[] nValues = {2, 3};
|
|
||||||
int prod = 2 * 3;
|
|
||||||
double[][] expVals = new double[][] {{0.0, 0.0}, {1.0, 0.0}, {0.0, 0.5}, {1.0, 0.5}, {0.0, 1.0}, {1.0, 1.0}};
|
|
||||||
for (int i = 0; i < prod; i++) {
|
|
||||||
double[] out = GridSearchCandidateGenerator.indexToValues(nValues, i, prod);
|
|
||||||
double[] exp = expVals[i];
|
|
||||||
assertArrayEquals(exp, out, 1e-4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testGeneration() throws Exception {
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
|
||||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
|
||||||
|
|
||||||
//Define configuration:
|
|
||||||
CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
|
|
||||||
GridSearchCandidateGenerator.Mode.Sequential, commands);
|
|
||||||
|
|
||||||
//Check sequential:
|
|
||||||
double[] expValuesFirst = {-5, 0, 5, 10}; //Range: -5 to +10, with 4 values
|
|
||||||
double[] expValuesSecond = {0, 5, 10, 15}; //Range: 0 to +15, with 4 values
|
|
||||||
for (int i = 0; i < 4 * 4; i++) {
|
|
||||||
BraninFunction.BraninConfig conf = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
|
|
||||||
double expF = expValuesFirst[i % 4]; //Changes most rapidly
|
|
||||||
double expS = expValuesSecond[i / 4];
|
|
||||||
|
|
||||||
double actF = conf.getX1();
|
|
||||||
double actS = conf.getX2();
|
|
||||||
|
|
||||||
assertEquals(expF, actF, 1e-4);
|
|
||||||
assertEquals(expS, actS, 1e-4);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Check random order. specifically: check that all values are generated, in some order
|
|
||||||
double[][] orderedOutput = new double[16][2];
|
|
||||||
for (int i = 0; i < expValuesFirst.length; i++) {
|
|
||||||
for (int j = 0; j < expValuesSecond.length; j++) {
|
|
||||||
orderedOutput[4 * j + i][0] = expValuesFirst[i];
|
|
||||||
orderedOutput[4 * j + i][1] = expValuesSecond[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
|
|
||||||
GridSearchCandidateGenerator.Mode.RandomOrder, commands);
|
|
||||||
boolean[] seen = new boolean[16];
|
|
||||||
int seenCount = 0;
|
|
||||||
for (int i = 0; i < 4 * 4; i++) {
|
|
||||||
assertTrue(candidateGenerator.hasMoreCandidates());
|
|
||||||
BraninFunction.BraninConfig config = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
|
|
||||||
double x1 = config.getX1();
|
|
||||||
double x2 = config.getX2();
|
|
||||||
//Work out which of the values this is...
|
|
||||||
boolean matched = false;
|
|
||||||
for (int j = 0; j < 16; j++) {
|
|
||||||
if (Math.abs(orderedOutput[j][0] - x1) < 1e-5 && Math.abs(orderedOutput[j][1] - x2) < 1e-5) {
|
|
||||||
matched = true;
|
|
||||||
if (seen[j])
|
|
||||||
fail("Same candidate generated multiple times");
|
|
||||||
seen[j] = true;
|
|
||||||
seenCount++;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assertTrue("Candidate " + x1 + ", " + x2 + " not found; invalid?", matched);
|
|
||||||
}
|
|
||||||
assertFalse(candidateGenerator.hasMoreCandidates());
|
|
||||||
assertEquals(16, seenCount);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,122 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import org.apache.commons.math3.distribution.LogNormalDistribution;
|
|
||||||
import org.apache.commons.math3.distribution.NormalDistribution;
|
|
||||||
import org.apache.commons.math3.distribution.UniformIntegerDistribution;
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
|
||||||
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
|
|
||||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
import org.nd4j.shade.jackson.datatype.joda.JodaModule;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Created by Alex on 02/02/2017.
|
|
||||||
*/
|
|
||||||
public class TestJson extends BaseDL4JTest {
|
|
||||||
|
|
||||||
protected static ObjectMapper getObjectMapper(JsonFactory factory) {
|
|
||||||
ObjectMapper om = new ObjectMapper(factory);
|
|
||||||
om.registerModule(new JodaModule());
|
|
||||||
om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
|
||||||
om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
|
||||||
om.enable(SerializationFeature.INDENT_OUTPUT);
|
|
||||||
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
|
|
||||||
om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
|
|
||||||
om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
|
|
||||||
return om;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static ObjectMapper jsonMapper = getObjectMapper(new JsonFactory());
|
|
||||||
private static ObjectMapper yamlMapper = getObjectMapper(new YAMLFactory());
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testParameterSpaceJson() throws Exception {
|
|
||||||
|
|
||||||
List<ParameterSpace<?>> l = new ArrayList<>();
|
|
||||||
l.add(new FixedValue<>(1.0));
|
|
||||||
l.add(new FixedValue<>(1));
|
|
||||||
l.add(new FixedValue<>("string"));
|
|
||||||
l.add(new ContinuousParameterSpace(-1, 1));
|
|
||||||
l.add(new ContinuousParameterSpace(new LogNormalDistribution(1, 1)));
|
|
||||||
l.add(new ContinuousParameterSpace(new NormalDistribution(2, 0.01)));
|
|
||||||
l.add(new DiscreteParameterSpace<>(1, 5, 7));
|
|
||||||
l.add(new DiscreteParameterSpace<>("first", "second", "third"));
|
|
||||||
l.add(new IntegerParameterSpace(0, 10));
|
|
||||||
l.add(new IntegerParameterSpace(new UniformIntegerDistribution(0, 50)));
|
|
||||||
l.add(new BooleanSpace());
|
|
||||||
|
|
||||||
for (ParameterSpace<?> ps : l) {
|
|
||||||
String strJson = jsonMapper.writeValueAsString(ps);
|
|
||||||
String strYaml = yamlMapper.writeValueAsString(ps);
|
|
||||||
|
|
||||||
ParameterSpace<?> fromJson = jsonMapper.readValue(strJson, ParameterSpace.class);
|
|
||||||
ParameterSpace<?> fromYaml = yamlMapper.readValue(strYaml, ParameterSpace.class);
|
|
||||||
|
|
||||||
assertEquals(ps, fromJson);
|
|
||||||
assertEquals(ps, fromYaml);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testCandidateGeneratorJson() throws Exception {
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
|
||||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
|
||||||
|
|
||||||
List<CandidateGenerator> l = new ArrayList<>();
|
|
||||||
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
|
|
||||||
GridSearchCandidateGenerator.Mode.Sequential, commands));
|
|
||||||
l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
|
|
||||||
GridSearchCandidateGenerator.Mode.RandomOrder, commands));
|
|
||||||
l.add(new RandomSearchGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), commands));
|
|
||||||
|
|
||||||
for (CandidateGenerator cg : l) {
|
|
||||||
String strJson = jsonMapper.writeValueAsString(cg);
|
|
||||||
String strYaml = yamlMapper.writeValueAsString(cg);
|
|
||||||
|
|
||||||
CandidateGenerator fromJson = jsonMapper.readValue(strJson, CandidateGenerator.class);
|
|
||||||
CandidateGenerator fromYaml = yamlMapper.readValue(strYaml, CandidateGenerator.class);
|
|
||||||
|
|
||||||
assertEquals(cg, fromJson);
|
|
||||||
assertEquals(cg, fromYaml);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,61 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
|
||||||
import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* Test random search on the Branin Function:
|
|
||||||
* http://www.sfu.ca/~ssurjano/branin.html
|
|
||||||
*/
|
|
||||||
public class TestRandomSearch extends BaseDL4JTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test() throws Exception {
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
|
||||||
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
|
|
||||||
|
|
||||||
//Define configuration:
|
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(new BraninFunction.BraninSpace(), commands);
|
|
||||||
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
||||||
.candidateGenerator(candidateGenerator).scoreFunction(new BraninFunction.BraninScoreFunction())
|
|
||||||
.terminationConditions(new MaxCandidatesCondition(50)).build();
|
|
||||||
|
|
||||||
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator());
|
|
||||||
|
|
||||||
runner.addListeners(new LoggingStatusListener());
|
|
||||||
runner.execute();
|
|
||||||
|
|
||||||
|
|
||||||
// System.out.println("----- Complete -----");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue