Eclipse -> Konduit update (#188)

* Update Japanese translation for Deeplearning4J UI (#8525)

Signed-off-by: k-tamura <ktamura.biz.80@gmail.com>

* RL4J: Remove processing done on observations in Policy & Async (#8471)

* Removed processing from Policy.play() and fixed missing resets

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Adjusted unit test to check if DQNs have been reset

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Fixed a couple of problems, added and updated unit tests

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Removed processing from AsyncThreadDiscrete

Signed-off-by: unknown <aboulang2002@yahoo.com>

* Fixed a few problems

Signed-off-by: unknown <aboulang2002@yahoo.com>

* python version bump

* increase

* RL4J: Replace gym-java-client with JavaCPP (#8595)

* RL4J: Replace gym-java-client with JavaCPP

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

Co-authored-by: Kohei Tamura <ktamura.biz.80@gmail.com>
Co-authored-by: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com>
Co-authored-by: Max Pumperla <max.pumperla@googlemail.com>
Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
master
Alex Black 2020-01-27 16:03:00 +11:00 committed by GitHub
parent 458d141d8e
commit 95db34e389
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 1167 additions and 2054 deletions

View File

@ -7,7 +7,6 @@ Welcome to the new monorepo of Deeplearning4j that contains the source code for
* https://github.com/eclipse/deeplearning4j/tree/master/datavec
* https://github.com/eclipse/deeplearning4j/tree/master/arbiter
* https://github.com/eclipse/deeplearning4j/tree/master/nd4s
* https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client
* https://github.com/eclipse/deeplearning4j/tree/master/rl4j
* https://github.com/eclipse/deeplearning4j/tree/master/scalnet
* https://github.com/eclipse/deeplearning4j/tree/master/pydl4j

View File

@ -4,5 +4,5 @@ train.nav.model=モデル
train.nav.system=システム
train.nav.userguide=ユーザーガイド
train.nav.language=言語
train.session.label=Session
train.session.worker.label=Worker
train.session.label=セッション
train.session.worker.label=ワーカー

View File

@ -5,9 +5,9 @@ train.model.lrChart.title=パラメータ学習率
train.model.lrChart.titleShort=学習率
train.model.paramHistChart.title=レイヤーパラメータヒストグラム
train.model.updateHistChart.title=レイヤー更新ヒストグラム
train.model.meanmag.btn.ratio=Ratio
train.model.meanmag.btn.param=Param
train.model.meanmag.btn.update=Updates
train.model.meanmag.btn.ratio=比率
train.model.meanmag.btn.param=パラメータ
train.model.meanmag.btn.update=更新
train.model.layerinfotable.layerName=レイヤー名
train.model.layerinfotable.layerType=レイヤータイプ
train.model.layerinfotable.layerNIn=入力サイズ
@ -19,4 +19,4 @@ train.model.layerinfotable.layerUpdater=更新の方法
train.model.layerinfotable.layerSubsamplingPoolingType=プーリングタイプ
train.model.layerinfotable.layerCnnKernel=カーネルサイズ
train.model.layerinfotable.layerCnnStride=ストライド
train.model.layerinfotable.layerCnnPadding=パディング
train.model.layerinfotable.layerCnnPadding=パディング

View File

@ -1,10 +1,10 @@
train.system.title=システム詳細
train.system.selectMachine=Select Machine
train.system.selectMachine=マシンを選択
train.system.chart.memoryShort=メモリ
train.system.chart.systemMemoryTitle=JVM and Off-Heap Memory Utilization
train.system.chart.gpuMemoryTitle=GPU Memory Utilization
train.system.chart.key.jvm=JVM Memory
train.system.chart.key.offHeap=Off Heap Memory
train.system.chart.systemMemoryTitle=JVMとオフヒープのメモリ使用率
train.system.chart.gpuMemoryTitle=GPUメモリ使用率
train.system.chart.key.jvm=JVMメモリ
train.system.chart.key.offHeap=オフヒープメモリ
train.system.hwTable.title=ハードウェアの情報
train.system.hwTable.jvmCurrent=JVM現在メモリ
train.system.hwTable.jvmMax=JVM最大メモリ
@ -13,7 +13,7 @@ train.system.hwTable.offHeapMax=オフヒープ最大メモリ
train.system.hwTable.jvmProcs=JVM使用可能プロセッサ
train.system.hwTable.computeDevices=計算デバイス数
train.system.hwTable.deviceMemory=デバイスメモリ
train.system.hwTable.deviceName=Device Name
train.system.hwTable.deviceName=デバイス名
train.system.swTable.title=ソフトウェアの情報
train.system.swTable.hostname=ホスト名
train.system.swTable.os=OSの種類
@ -21,4 +21,4 @@ train.system.swTable.osArch=OSのアーキテクチャ
train.system.swTable.jvmName=JVM名
train.system.swTable.jvmVersion=JVMバージョン
train.system.swTable.nd4jBackend=ND4Jバックエンド
train.system.swTable.nd4jDataType=ND4Jデータ型
train.system.swTable.nd4jDataType=ND4Jデータ型

View File

@ -1,4 +0,0 @@
target/
.idea/
*.iml
*-git.properties

View File

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,108 +0,0 @@
# gym-java-client
A java http client for [gym-http-api](https://github.com/openai/gym-http-api).
Note: If you are encountering errors as reported in [issue #13](https://github.com/deeplearning4j/gym-java-client/issues/13), please execute the following command before launching `python gym_http_server.py`:
```bash
$ sudo sysctl -w net.ipv4.tcp_tw_recycle=1
```
# Quickstart
To create a new Client, use the ClientFactory. If the url is not localhost:5000, provide it as a second argument
```java
Client<Box, Integer, DiscreteSpace> client = ClientFactory.build("CartPole-v0");
```
"CartPole-v0" is the name of the gym environment.
The type parameters of a client are the Observation type, the Action type, the Observation Space type and the ActionSpace type.
It is a bit cumbersome to both declare an ActionSpace and an Action since an ActionSpace knows what type is an Action but unfortunately java does't support type member and path dependant types.
Here we use Box and BoxSpace for the environment and Integer and Discrete Space because it is how [CartPole-v0](https://gym.openai.com/envs/CartPole-v0) is specified.
The methods nomenclature follows closely the api interface of [gym-http-api](https://github.com/openai/gym-http-api#api-specification), O is Observation an A is Action:
```java
//Static methods
/**
* @param url url of the server
* @return set of all environments running on the server at the url
*/
public static Set<String> listAll(String url);
/**
* Shutdown the server at the url
*
* @param url url of the server
*/
public static void serverShutdown(String url);
//Methods accessible from a Client
/**
* @return set of all environments running on the same server than this client
*/
public Set<String> listAll();
/**
* Step the environment by one action
*
* @param action action to step the environment with
* @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information.
*/
public StepReply<O> step(A action);
/**
* Reset the state of the environment and return an initial observation.
*
* @return initial observation
*/
public O reset();
/**
* Start monitoring.
*
* @param directory path to directory in which store the monitoring file
* @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.")
* @param resume retain the training data already in this directory, which will be merged with our new data
*/
public void monitorStart(String directory, boolean force, boolean resume);
/**
* Flush all monitor data to disk
*/
public void monitorClose();
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
* @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running.
**/
public void upload(String trainingDir, String apiKey, String algorithmId);
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
*/
public void upload(String trainingDir, String apiKey);
/**
* Shutdown the server at the same url than this client
*/
public void serverShutdown()
```
## TODO
* Add all ObservationSpace and ActionSpace when they will be available.

View File

@ -1,354 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<profiles version="13">
<profile kind="CodeFormatterProfile" name="GoogleStyle" version="13">
<setting id="org.eclipse.jdt.core.formatter.comment.insert_new_line_before_root_tags" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.disabling_tag" value="@formatter:off"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_annotation" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_type_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_type_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_type_arguments" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_anonymous_type_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_case" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_brace_in_array_initializer" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.new_lines_at_block_boundaries" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_cascading_method_invocation_with_arguments.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_annotation_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_annotation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_field" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_while" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.use_on_off_tags" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_prefer_two_fragments" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_parens_in_annotation_type_member_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_before_else_in_if_statement" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_prefix_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.keep_else_statement_on_same_line" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_ellipsis" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.insert_new_line_for_parameter" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_comment_inline_tags" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_annotation_type_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_breaks_compare_to_cases" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_at_in_annotation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_local_variable_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_multiple_fields" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_parameter" value="1040"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_expressions_in_array_initializer" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_type.count_dependent" value="1585|-1|1585"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_conditional_expression" value="80"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_for" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_multiple_fields.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_binary_operator" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_question_in_wildcard" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_array_initializer" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_parens_in_enum_constant" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_before_finally_in_try_statement" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_local_variable" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_before_catch_in_try_statement" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_while" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_after_package" value="1"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_qualified_allocation_expression.count_dependent" value="16|4|80"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_throws_clause_in_method_declaration.count_dependent" value="16|4|48"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_type_parameters" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.continuation_indentation" value="4"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_superinterfaces_in_enum_declaration.count_dependent" value="16|4|49"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_postfix_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_method_invocation" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_angle_bracket_in_type_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_superinterfaces" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_new_chunk" value="1"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_binary_operator" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_package" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_cascading_method_invocation_with_arguments" value="16"/>
<setting id="org.eclipse.jdt.core.compiler.source" value="1.7"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_throws_clause_in_constructor_declaration.count_dependent" value="16|4|48"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_enum_constant_arguments" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_constructor_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_line_comments" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_closing_angle_bracket_in_type_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_enum_declarations" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.join_wrapped_lines" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_block" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_explicit_constructor_call" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_non_simple_local_variable_annotation" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_method_invocation_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.align_type_members_on_columns" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_member_type" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_enum_constant" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_enum_constants.count_dependent" value="16|5|48"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_for" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_method_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_selector_in_method_invocation" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_unary_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_case" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.indent_parameter_description" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_method_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_switch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_enum_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_angle_bracket_in_type_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.clear_blank_lines_in_block_comment" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_type_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.lineSplit" value="120"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_if" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_selector_in_method_invocation.count_dependent" value="16|4|48"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_brackets_in_array_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_parenthesized_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_explicitconstructorcall_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_constructor_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_first_class_body_declaration" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_method" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indentation.size" value="4"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_parens_in_method_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.enabling_tag" value="@formatter:on"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_enum_constant" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_package" value="1585"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_superclass_in_type_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_assignment" value="16"/>
<setting id="org.eclipse.jdt.core.compiler.problem.assertIdentifier" value="error"/>
<setting id="org.eclipse.jdt.core.formatter.tabulation.char" value="space"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_semicolon_in_try_resources" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_constructor_declaration_parameters" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_prefix_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_statements_compare_to_body" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_method" value="1"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_outer_expressions_when_nested" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_non_simple_type_annotation" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.format_guardian_clause_on_one_line" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_for" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_field_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_cast" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_parameters_in_constructor_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_labeled_statement" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_annotation_type_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_method_body" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_method_declaration" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_try" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_method_invocation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_bracket_in_array_allocation_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_enum_constant" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_annotation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_at_in_annotation_type_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_method_declaration_throws" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_if" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_switch" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_method_declaration_throws" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_parenthesized_expression_in_return" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_annotation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_question_in_conditional" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_question_in_wildcard" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_try" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_bracket_in_array_allocation_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.preserve_white_space_between_code_and_line_comments" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_parenthesized_expression_in_throw" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_type_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.compiler.problem.enumIdentifier" value="error"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_generic_type_arguments" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.indent_switchstatements_compare_to_switch" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.comment_new_line_at_start_of_html_paragraph" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_ellipsis" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_block" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comment_prefix" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_for_inits" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_method_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.compact_else_if" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_non_simple_parameter_annotation" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_before_or_operator_multicatch" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_array_initializer" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_for_increments" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_method" value="1585"/>
<setting id="org.eclipse.jdt.core.formatter.format_line_comment_starting_on_first_column" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_bracket_in_array_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_field" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.indent_root_tags" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_enum_constant" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_enum_declarations" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_union_type_in_multicatch" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_explicitconstructorcall_arguments" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_method_declaration_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_superinterfaces" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_allocation_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.tabulation.size" value="4"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_bracket_in_array_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_opening_brace_in_array_initializer" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_closing_brace_in_block" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_bracket_in_array_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_enum_constant" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_angle_bracket_in_type_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_constructor_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_constructor_declaration_throws" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_if" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_method_invocation.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.comment.clear_blank_lines_in_javadoc_comment" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_parameter.count_dependent" value="1040|-1|1040"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_throws_clause_in_constructor_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_assignment_operator" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_assignment_operator" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_package.count_dependent" value="1585|-1|1585"/>
<setting id="org.eclipse.jdt.core.formatter.indent_empty_lines" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_synchronized" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_closing_paren_in_cast" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_method_declaration_parameters" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.force_if_else_statement_brace" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_block_in_case" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.number_of_empty_lines_to_preserve" value="3"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_non_simple_package_annotation" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_method_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_catch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_constructor_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_method_invocation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_bracket_in_array_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_qualified_allocation_expression" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_annotation.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_and_in_type_parameter" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_type" value="1585"/>
<setting id="org.eclipse.jdt.core.compiler.compliance" value="1.7"/>
<setting id="org.eclipse.jdt.core.formatter.continuation_indentation_for_array_initializer" value="4"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_brackets_in_array_allocation_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_at_in_annotation_type_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_allocation_expression" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_cast" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_unary_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_new_anonymous_class" value="20"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_angle_bracket_in_parameterized_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_local_variable.count_dependent" value="1585|-1|1585"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_anonymous_type_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.keep_empty_array_initializer_on_one_line" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_enum_declaration" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_field.count_dependent" value="1585|-1|1585"/>
<setting id="org.eclipse.jdt.core.formatter.keep_imple_if_on_one_line" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_constructor_declaration_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_closing_angle_bracket_in_type_parameters" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_labeled_statement" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_at_end_of_file_if_missing" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_for" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_parameterized_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_parameters_in_constructor_declaration.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_superinterfaces_in_type_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_enum_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_binary_expression" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_while" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_type" value="insert"/>
<setting id="org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode" value="enabled"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_try" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.put_empty_statement_on_new_line" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_label" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_parameter" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_angle_bracket_in_type_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_parens_in_method_invocation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_javadoc_comments" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_enum_constant" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_before_while_in_do_statement" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_enum_constant.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.comment.line_length" value="120"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_after_annotation_on_package" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_between_import_groups" value="1"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_enum_constant_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_semicolon" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_constructor_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.number_of_blank_lines_at_beginning_of_method_body" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_conditional" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_body_declarations_compare_to_type_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_annotation_type_member_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_before_binary_operator" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_between_type_declarations" value="2"/>
<setting id="org.eclipse.jdt.core.formatter.indent_body_declarations_compare_to_enum_declaration_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_synchronized" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_superinterfaces_in_enum_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.indent_statements_compare_to_block" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.join_lines_in_comments" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_question_in_conditional" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_multiple_field_declarations" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_compact_if" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_for_inits" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_switchstatements_compare_to_cases" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_array_initializer" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_default" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_and_in_type_parameter" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_parens_in_constructor_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_before_imports" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_assert" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_field" value="1585"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_html" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_throws_clause_in_method_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_angle_bracket_in_type_parameters" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_bracket_in_array_allocation_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_anonymous_type_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_colon_in_conditional" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_angle_bracket_in_parameterized_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_for" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_expressions_in_array_initializer.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_postfix_operator" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_source_code" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_synchronized" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_allocation_expression" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_constructor_declaration_throws" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_parameters_in_method_declaration" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_brace_in_array_initializer" value="do not insert"/>
<setting id="org.eclipse.jdt.core.compiler.codegen.targetPlatform" value="1.7"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_resources_in_try" value="80"/>
<setting id="org.eclipse.jdt.core.formatter.use_tabs_only_for_leading_indentations" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_annotation" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.comment.format_block_comments" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_enum_constant" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_enum_constants" value="0"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_parenthesized_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_body_declarations_compare_to_annotation_declaration_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_new_line_in_empty_block" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_parenthesized_expression" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_catch" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_multiple_local_declarations" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_superinterfaces_in_type_declaration.count_dependent" value="16|4|48"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_switch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_comma_in_for_increments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_method.count_dependent" value="1585|-1|1585"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_method_invocation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_colon_in_assert" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.brace_position_for_type_declaration" value="end_of_line"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_brace_in_array_initializer" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_between_empty_braces_in_array_initializer" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_binary_expression.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.wrap_non_simple_member_annotation" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_annotations_on_local_variable" value="1585"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_opening_paren_in_method_declaration" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_semicolon_in_for" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_explicit_constructor_call.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_paren_in_catch" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_angle_bracket_in_parameterized_type_reference" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_multiple_field_declarations" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_closing_paren_in_annotation" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_generic_type_arguments.count_dependent" value="16|-1|16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_parameterized_type_reference" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_arguments_in_allocation_expression.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_method_invocation_arguments" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_parameters_in_method_declaration.count_dependent" value="16|5|80"/>
<setting id="org.eclipse.jdt.core.formatter.comment.new_lines_at_javadoc_boundaries" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.blank_lines_after_imports" value="1"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_comma_in_multiple_local_declarations" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.indent_body_declarations_compare_to_enum_constant_header" value="true"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_after_semicolon_in_for" value="insert"/>
<setting id="org.eclipse.jdt.core.formatter.never_indent_line_comments_on_first_column" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_semicolon_in_try_resources" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.alignment_for_for_statement" value="16"/>
<setting id="org.eclipse.jdt.core.formatter.insert_space_before_opening_angle_bracket_in_type_arguments" value="do not insert"/>
<setting id="org.eclipse.jdt.core.formatter.never_indent_block_comments_on_first_column" value="false"/>
<setting id="org.eclipse.jdt.core.formatter.keep_then_statement_on_same_line" value="false"/>
</profile>
</profiles>

View File

@ -1,337 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<groupId>org.deeplearning4j</groupId>
<artifactId>gym-java-client</artifactId>
<name>gym-java-client</name>
<description>A Java client for Open AI's Reinforcement Learning Gym</description>
<licenses>
<license>
<name>Apache License, Version 2.0</name>
<url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
</license>
</licenses>
<developers>
<developer>
<id>rubenfiszel</id>
<name>Ruben Fiszel</name>
<email>ruben.fiszel@epfl.ch</email>
</developer>
</developers>
<properties>
<nd4j.backend>nd4j-native</nd4j.backend>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>${commons-codec.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
<version>${httpclient.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpcore</artifactId>
<version>${httpcore.version}</version>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpmime</artifactId>
<version>${httpmime.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
</dependency>
<dependency>
<groupId>org.objenesis</groupId>
<artifactId>objenesis</artifactId>
<version>${objenesis.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>${asm.version}</version>
</dependency>
<dependency>
<groupId>cglib</groupId>
<artifactId>cglib</artifactId>
<version>3.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-api-mockito2</artifactId>
<version>1.7.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.powermock</groupId>
<artifactId>powermock-module-junit4</artifactId>
<version>1.7.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<artifactId>maven-source-plugin</artifactId>
<version>${maven-source-plugin.version}</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>${maven-surefire-plugin.version}</version>
<configuration>
<!--
By default: Surefire will set the classpath based on the manifest. Because tests are not included
in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
function correctly without this configuratino.
For example, tests for custom layers (where the custom layer is defined in the test directory)
will fail due to the custom layer not being found on the classpath.
http://maven.apache.org/surefire/maven-surefire-plugin/examples/class-loading.html
-->
<useSystemClassLoader>true</useSystemClassLoader>
<useManifestOnlyJar>false</useManifestOnlyJar>
</configuration>
</plugin>
<plugin>
<artifactId>maven-javadoc-plugin</artifactId>
<version>${maven-javadoc-plugin.version}</version>
<configuration>
<additionalparam>-Xdoclint:none</additionalparam>
</configuration>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>com.lewisd</groupId>
<artifactId>lint-maven-plugin</artifactId>
<version>0.0.11</version>
<configuration>
<failOnViolation>true</failOnViolation>
<onlyRunRules>
<rule>DuplicateDep</rule>
<rule>RedundantDepVersion</rule>
<rule>RedundantPluginVersion</rule>
<!-- Rules incompatible with Java 9
<rule>VersionProp</rule>
<rule>DotVersionProperty</rule> -->
</onlyRunRules>
</configuration>
<executions>
<execution>
<id>pom-lint</id>
<phase>validate</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>net.revelc.code.formatter</groupId>
<artifactId>formatter-maven-plugin</artifactId>
<version>2.0.0</version>
<configuration>
<configFile>${session.executionRootDirectory}/contrib/formatter.xml</configFile>
</configuration>
</plugin>
<!-- Configuration for git-commit-id plugin - used with ND4J version check functionality -->
<plugin>
<groupId>pl.project13.maven</groupId>
<artifactId>git-commit-id-plugin</artifactId>
<version>${maven-git-commit-plugin.version}</version>
<executions>
<execution>
<goals>
<goal>revision</goal>
</goals>
<phase>generate-resources</phase>
</execution>
</executions>
<configuration>
<generateGitPropertiesFile>true</generateGitPropertiesFile>
<generateGitPropertiesFilename>
${project.basedir}/target/generated-sources/src/main/resources/ai/skymind/${project.groupId}-${project.artifactId}-git.properties
</generateGitPropertiesFilename>
<gitDescribe>
<skip>true</skip>
</gitDescribe>
</configuration>
</plugin>
<!-- Add generated git.properties files resource directory, for output of git-commit-id plugin -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<version>${maven-build-helper-plugin.version}</version>
<executions>
<execution>
<id>add-resource</id>
<phase>generate-resources</phase>
<goals>
<goal>add-resource</goal>
</goals>
<configuration>
<resources>
<resource>
<directory>
${project.basedir}/target/generated-sources/src/main/resources
</directory>
</resource>
</resources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.eclipse.m2e</groupId>
<artifactId>lifecycle-mapping</artifactId>
<version>1.0.0</version>
<configuration>
<lifecycleMappingMetadata>
<pluginExecutions>
<pluginExecution>
<pluginExecutionFilter>
<groupId>com.lewisd</groupId>
<artifactId>lint-maven-plugin</artifactId>
<versionRange>[0.0.11,)</versionRange>
<goals>
<goal>check</goal>
</goals>
</pluginExecutionFilter>
<action>
<ignore/>
</action>
</pluginExecution>
</pluginExecutions>
</lifecycleMappingMetadata>
</configuration>
</plugin>
</plugins>
</pluginManagement>
</build>
<reporting>
<plugins>
<plugin>
<artifactId>maven-surefire-report-plugin</artifactId>
<version>2.19.1</version>
</plugin>
<!-- Test coverage -->
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>cobertura-maven-plugin</artifactId>
<version>2.7</version>
</plugin>
</plugins>
</reporting>
<profiles>
<profile>
<id>nd4j-backend</id>
<activation>
<property>
<name>libnd4j.cuda</name>
</property>
</activation>
<properties>
<nd4j.backend>nd4j-cuda-${libnd4j.cuda}</nd4j.backend>
</properties>
</profile>
</profiles>
</project>

View File

@ -1,200 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gym;
import com.mashape.unirest.http.JsonNode;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.space.GymObservationSpace;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.json.JSONObject;
import java.util.Set;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16.
*
* A client represent an active connection to a specific instance of an environment on a rl4j-http-api server.
* for API specification
*
* @param <O> Observation type
* @param <A> Action type
* @param <AS> Action Space type
* @see <a href="https://github.com/openai/gym-http-api#api-specification">https://github.com/openai/gym-http-api#api-specification</a>
*/
@Slf4j
@Value
public class Client<O, A, AS extends ActionSpace<A>> {
public static String V1_ROOT = "/v1";
public static String ENVS_ROOT = V1_ROOT + "/envs/";
public static String MONITOR_START = "/monitor/start/";
public static String MONITOR_CLOSE = "/monitor/close/";
public static String CLOSE = "/close/";
public static String RESET = "/reset/";
public static String SHUTDOWN = "/shutdown/";
public static String UPLOAD = "/upload/";
public static String STEP = "/step/";
public static String OBSERVATION_SPACE = "/observation_space/";
public static String ACTION_SPACE = "/action_space/";
String url;
String envId;
String instanceId;
GymObservationSpace<O> observationSpace;
AS actionSpace;
boolean render;
/**
* @param url url of the server
* @return set of all environments running on the server at the url
*/
public static Set<String> listAll(String url) {
JSONObject reply = ClientUtils.get(url + ENVS_ROOT);
return reply.getJSONObject("envs").keySet();
}
/**
* Shutdown the server at the url
*
* @param url url of the server
*/
public static void serverShutdown(String url) {
ClientUtils.post(url + ENVS_ROOT + SHUTDOWN, new JSONObject());
}
/**
* @return set of all environments running on the same server than this client
*/
public Set<String> listAll() {
return listAll(url);
}
/**
* Step the environment by one action
*
* @param action action to step the environment with
* @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information.
*/
public StepReply<O> step(A action) {
JSONObject body = new JSONObject().put("action", getActionSpace().encode(action)).put("render", render);
JSONObject reply = ClientUtils.post(url + ENVS_ROOT + instanceId + STEP, body).getObject();
O observation = observationSpace.getValue(reply, "observation");
double reward = reply.getDouble("reward");
boolean done = reply.getBoolean("done");
JSONObject info = reply.getJSONObject("info");
return new StepReply<O>(observation, reward, done, info);
}
/**
* Reset the state of the environment and return an initial observation.
*
* @return initial observation
*/
public O reset() {
JsonNode resetRep = ClientUtils.post(url + ENVS_ROOT + instanceId + RESET, new JSONObject());
return observationSpace.getValue(resetRep.getObject(), "observation");
}
/*
Present in the doc but not working currently server-side
public void monitorStart(String directory) {
JSONObject json = new JSONObject()
.put("directory", directory);
monitorStartPost(json);
}
*/
/**
* Start monitoring.
*
* @param directory path to directory in which store the monitoring file
* @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.")
* @param resume retain the training data already in this directory, which will be merged with our new data
*/
public void monitorStart(String directory, boolean force, boolean resume) {
JSONObject json = new JSONObject().put("directory", directory).put("force", force).put("resume", resume);
monitorStartPost(json);
}
private void monitorStartPost(JSONObject json) {
ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_START, json);
}
/**
* Flush all monitor data to disk
*/
public void monitorClose() {
ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_CLOSE, new JSONObject());
}
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
* @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running.
**/
public void upload(String trainingDir, String apiKey, String algorithmId) {
JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey).put("algorithm_id",
algorithmId);
uploadPost(json);
}
/**
* Upload monitoring data to OpenAI servers.
*
* @param trainingDir directory that contains the monitoring data
* @param apiKey personal OpenAI API key
*/
public void upload(String trainingDir, String apiKey) {
JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey);
uploadPost(json);
}
private void uploadPost(JSONObject json) {
try {
ClientUtils.post(url + V1_ROOT + UPLOAD, json);
} catch (RuntimeException e) {
log.error("Impossible to upload: Wrong API key?");
}
}
/**
* Shutdown the server at the same url than this client
*/
public void serverShutdown() {
serverShutdown(url);
}
public ActionSpace getActionSpace(){
return actionSpace;
}
}

View File

@ -1,90 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gym;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.GymObservationSpace;
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
*
* ClientFactory contains builder method to create a new {@link Client}
*/
public class ClientFactory {
public static <O, A, AS extends ActionSpace<A>> Client<O, A, AS> build(String url, String envId, boolean render) {
JSONObject body = new JSONObject().put("env_id", envId);
JSONObject reply = ClientUtils.post(url + Client.ENVS_ROOT, body).getObject();
String instanceId;
try {
instanceId = reply.getString("instance_id");
} catch (JSONException e) {
throw new RuntimeException("Environment id not found", e);
}
GymObservationSpace<O> observationSpace = fetchObservationSpace(url, instanceId);
AS actionSpace = fetchActionSpace(url, instanceId);
return new Client(url, envId, instanceId, observationSpace, actionSpace, render);
}
public static <O, A, AS extends ActionSpace<A>> Client<O, A, AS> build(String envId, boolean render) {
return build("http://127.0.0.1:5000", envId, render);
}
public static <AS extends ActionSpace> AS fetchActionSpace(String url, String instanceId) {
JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.ACTION_SPACE);
JSONObject info = reply.getJSONObject("info");
String infoName = info.getString("name");
switch (infoName) {
case "Discrete":
return (AS) new DiscreteSpace(info.getInt("n"));
case "HighLow":
int numRows = info.getInt("num_rows");
int size = 3 * numRows;
JSONArray matrixJson = info.getJSONArray("matrix");
INDArray matrix = Nd4j.create(numRows, 3);
for (int i = 0; i < size; i++) {
matrix.putScalar(i, matrixJson.getDouble(i));
}
matrix.reshape(numRows, 3);
return (AS) new HighLowDiscrete(matrix);
default:
throw new RuntimeException("Unknown action space " + infoName);
}
}
public static <O> GymObservationSpace<O> fetchObservationSpace(String url, String instanceId) {
JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.OBSERVATION_SPACE);
JSONObject info = reply.getJSONObject("info");
return new GymObservationSpace<O>(info);
}
}

View File

@ -1,75 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gym;
import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import org.json.JSONException;
import org.json.JSONObject;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
*
* ClientUtils contain the utility methods to post and get data from the server REST API through the library unirest.
*/
public class ClientUtils {
static public JsonNode post(String url, JSONObject json) {
HttpResponse<JsonNode> jsonResponse = null;
try {
jsonResponse = Unirest.post(url).header("content-type", "application/json").body(json).asJson();
} catch (UnirestException e) {
unirestCrash(e);
}
return jsonResponse.getBody();
}
static public JSONObject get(String url) {
HttpResponse<JsonNode> jsonResponse = null;
try {
jsonResponse = Unirest.get(url).header("content-type", "application/json").asJson();
} catch (UnirestException e) {
unirestCrash(e);
}
checkReply(jsonResponse, url);
return jsonResponse.getBody().getObject();
}
static public void checkReply(HttpResponse<JsonNode> res, String url) {
if (res.getBody() == null)
throw new RuntimeException("Invalid reply at: " + url);
}
static public void unirestCrash(UnirestException e) {
//if couldn't parse json
if (e.getCause().getCause().getCause() instanceof JSONException)
throw new RuntimeException("Couldn't parse json reply.", e);
else
throw new RuntimeException("Connection error", e);
}
}

View File

@ -1,82 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.space;
import lombok.Value;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
*
* Contain contextual information about the environment from which Observations are observed and must know how to build an Observation from json.
*
* @param <O> the type of Observation
*/
@Value
public class GymObservationSpace<O> implements ObservationSpace<O> {
String name;
int[] shape;
INDArray low;
INDArray high;
public GymObservationSpace(JSONObject jsonObject) {
name = jsonObject.getString("name");
JSONArray arr = jsonObject.getJSONArray("shape");
int lg = arr.length();
shape = new int[lg];
for (int i = 0; i < lg; i++) {
this.shape[i] = arr.getInt(i);
}
low = Nd4j.create(shape);
high = Nd4j.create(shape);
JSONArray lowJson = jsonObject.getJSONArray("low");
JSONArray highJson = jsonObject.getJSONArray("high");
int size = shape[0];
for (int i = 1; i < shape.length; i++) {
size *= shape[i];
}
for (int i = 0; i < size; i++) {
low.putScalar(i, lowJson.getDouble(i));
high.putScalar(i, highJson.getDouble(i));
}
}
public O getValue(JSONObject o, String key) {
switch (name) {
case "Box":
JSONArray arr = o.getJSONArray(key);
return (O) new Box(arr);
default:
throw new RuntimeException("Invalid environment name: " + name);
}
}
}

View File

@ -1,140 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gym.test;
import com.mashape.unirest.http.JsonNode;
import org.deeplearning4j.gym.Client;
import org.deeplearning4j.gym.ClientFactory;
import org.deeplearning4j.gym.ClientUtils;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.json.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import static org.mockito.Matchers.eq;
import static org.powermock.api.mockito.PowerMockito.mockStatic;
import static org.powermock.api.mockito.PowerMockito.when;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16.
*/
@RunWith(PowerMockRunner.class)
@PrepareForTest(ClientUtils.class)
public class ClientTest {
@Test
public void completeClientTest() {
String url = "http://127.0.0.1:5000";
String env = "Powermock-v0";
String instanceID = "e15739cf";
String testDir = "/tmp/testDir";
boolean render = true;
String renderStr = render ? "True" : "False";
mockStatic(ClientUtils.class);
//post mock
JSONObject buildReq = new JSONObject("{\"env_id\":\"" + env + "\"}");
JsonNode buildRep = new JsonNode("{\"instance_id\":\"" + instanceID + "\"}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT), JSONObjectMatcher.jsonEq(buildReq))).thenReturn(buildRep);
JSONObject monStartReq = new JSONObject("{\"resume\":false,\"directory\":\"" + testDir + "\",\"force\":true}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_START),
JSONObjectMatcher.jsonEq(monStartReq))).thenReturn(null);
JSONObject monStopReq = new JSONObject("{}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_CLOSE),
JSONObjectMatcher.jsonEq(monStopReq))).thenReturn(null);
JSONObject resetReq = new JSONObject("{}");
JsonNode resetRep = new JsonNode(
"{\"observation\":[0.021729452941849317,-0.04764548144956857,-0.024914502756611293,-0.04074903379512588]}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.RESET),
JSONObjectMatcher.jsonEq(resetReq))).thenReturn(resetRep);
JSONObject stepReq = new JSONObject("{\"action\":0, \"render\":" + renderStr + "}");
JsonNode stepRep = new JsonNode(
"{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), JSONObjectMatcher.jsonEq(stepReq)))
.thenReturn(stepRep);
JSONObject stepReq2 = new JSONObject("{\"action\":1, \"render\":" + renderStr + "}");
JsonNode stepRep2 = new JsonNode(
"{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP),
JSONObjectMatcher.jsonEq(stepReq2))).thenReturn(stepRep2);
//get mock
JSONObject obsSpace = new JSONObject(
"{\"info\":{\"name\":\"Box\",\"shape\":[4],\"high\":[4.8,3.4028234663852886E38,0.41887902047863906,3.4028234663852886E38],\"low\":[-4.8,-3.4028234663852886E38,-0.41887902047863906,-3.4028234663852886E38]}}");
when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.OBSERVATION_SPACE))).thenReturn(obsSpace);
JSONObject actionSpace = new JSONObject("{\"info\":{\"name\":\"Discrete\",\"n\":2}}");
when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.ACTION_SPACE))).thenReturn(actionSpace);
//test
Client<Box, Integer, DiscreteSpace> client = ClientFactory.build(url, env, render);
client.monitorStart(testDir, true, false);
int episodeCount = 1;
int maxSteps = 200;
int reward = 0;
for (int i = 0; i < episodeCount; i++) {
client.reset();
for (int j = 0; j < maxSteps; j++) {
Integer action = ((ActionSpace<Integer>)client.getActionSpace()).randomAction();
StepReply<Box> step = client.step(action);
reward += step.getReward();
//return a isDone true before i == maxSteps
if (j == maxSteps - 5) {
JSONObject stepReqLoc = new JSONObject("{\"action\":0}");
JsonNode stepRepLoc = new JsonNode(
"{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":true,\"info\":{}}");
when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP),
JSONObjectMatcher.jsonEq(stepReqLoc))).thenReturn(stepRepLoc);
}
if (step.isDone()) {
// System.out.println("break");
break;
}
}
}
client.monitorClose();
client.upload(testDir, "YOUR_OPENAI_GYM_API_KEY");
}
}

View File

@ -1,46 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.gym.test;
import org.json.JSONObject;
import org.mockito.ArgumentMatcher;
import static org.mockito.Matchers.argThat;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16.
*/
public class JSONObjectMatcher implements ArgumentMatcher<JSONObject> {
private final JSONObject expected;
public JSONObjectMatcher(JSONObject expected) {
this.expected = expected;
}
public static JSONObject jsonEq(JSONObject expected) {
return argThat(new JSONObjectMatcher(expected));
}
@Override
public boolean matches(JSONObject argument) {
if (expected == null)
return argument == null;
return expected.toString().equals(argument.toString()); }
}

22
pom.xml
View File

@ -132,7 +132,6 @@
<module>deeplearning4j</module>
<module>arbiter</module>
<module>nd4s</module>
<module>gym-java-client</module>
<module>rl4j</module>
<module>scalnet</module>
<module>jumpy</module>
@ -288,22 +287,23 @@
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
<javacpp.version>1.5.2</javacpp.version>
<javacpp-presets.version>1.5.2</javacpp-presets.version>
<javacv.version>1.5.2</javacv.version>
<javacpp.version>1.5.3-SNAPSHOT</javacpp.version>
<javacpp-presets.version>1.5.3-SNAPSHOT</javacpp-presets.version>
<javacv.version>1.5.3-SNAPSHOT</javacv.version>
<python.version>3.7.5</python.version>
<python.version>3.7.6</python.version>
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
<numpy.version>1.17.3</numpy.version>
<numpy.version>1.18.0</numpy.version>
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
<openblas.version>0.3.7</openblas.version>
<mkl.version>2019.5</mkl.version>
<opencv.version>4.1.2</opencv.version>
<ffmpeg.version>4.2.1</ffmpeg.version>
<leptonica.version>1.78.0</leptonica.version>
<hdf5.version>1.10.5</hdf5.version>
<ale.version>0.6.0</ale.version>
<opencv.version>4.2.0</opencv.version>
<ffmpeg.version>4.2.2</ffmpeg.version>
<leptonica.version>1.79.0</leptonica.version>
<hdf5.version>1.10.6</hdf5.version>
<ale.version>0.6.1</ale.version>
<gym.version>0.15.4</gym.version>
<tensorflow.version>1.15.0</tensorflow.version>
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>

View File

@ -1,27 +1,27 @@
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
from setuptools import setup
from setuptools import find_packages
setup(
name='pydl4j',
version='0.1.3',
version='0.1.5',
packages=find_packages(),
install_requires=['Cython', 'jnius', 'requests',
install_requires=['Cython', 'pyjnius', 'requests',
'click', 'argcomplete', 'python-dateutil'],
extras_require={
'tests': ['pytest', 'pytest-pep8', 'pytest-cov']

View File

@ -32,10 +32,6 @@ Comments are welcome on our gitter channel:
# Quickstart
** INSTALL rl4j-api before installing all (see below)!**
* mvn install -pl rl4j-api
* [if you want rl4j-gym too] Download and mvn install: [gym-java-client](https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client)
* mvn install
# Visualisation
@ -44,9 +40,7 @@ Comments are welcome on our gitter channel:
# Quicktry cartpole:
* Install [gym-http-api](https://github.com/openai/gym-http-api).
* launch http api server.
* run with this [main](https://github.com/rubenfiszel/rl4j-examples/blob/master/src/main/java/org/deeplearning4j/rl4j/Cartpole.java)
* run with this [main](https://github.com/eclipse/deeplearning4j-examples/blob/master/rl4j-examples/src/main/java/org/deeplearning4j/examples/rl4j/Cartpole.java)
# Doom
@ -83,4 +77,4 @@ Doom is not ready yet but you can make it work if you feel adventurous with some
* Continuous control
* Policy Gradient
* Update gym-java-client when gym-http-api gets compatible with pixels environments to play with Pong, Doom, etc ..
* Update rl4j-gym to make it compatible with pixels environments to play with Pong, Doom, etc ..

View File

@ -97,6 +97,30 @@
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<version>${maven-enforcer-plugin.version}</version>
<executions>
<execution>
<phase>test</phase>
<id>enforce-choice-of-nd4j-test-backend</id>
<goals>
<goal>enforce</goal>
</goals>
<configuration>
<skip>${skipBackendChoice}</skip>
<rules>
<requireActiveProfile>
<profiles>test-nd4j-native,test-nd4j-cuda-10.2</profiles>
<all>false</all>
</requireActiveProfile>
</rules>
<fail>true</fail>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-source-plugin</artifactId>
<version>${maven-source-plugin.version}</version>
@ -265,6 +289,32 @@
</pluginManagement>
</build>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.2</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<reporting>
<plugins>
<plugin>

View File

@ -44,4 +44,13 @@
<version>${ale.version}-${javacpp-presets.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -33,15 +33,19 @@
</properties>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>gym-java-client</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -17,7 +17,6 @@
package org.deeplearning4j.gym;
import lombok.Value;
import org.json.JSONObject;
/**
* @param <T> type of observation
@ -31,6 +30,6 @@ public class StepReply<T> {
T observation;
double reward;
boolean done;
JSONObject info;
Object info;
}

View File

@ -16,8 +16,6 @@
package org.deeplearning4j.rl4j.space;
import org.json.JSONArray;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
*
@ -29,14 +27,8 @@ public class Box implements Encodable {
private final double[] array;
public Box(JSONArray arr) {
int lg = arr.length();
this.array = new double[lg];
for (int i = 0; i < lg; i++) {
this.array[i] = arr.getDouble(i);
}
public Box(double[] arr) {
this.array = arr;
}
public double[] toArray() {

View File

@ -17,10 +17,8 @@
package org.deeplearning4j.rl4j.space;
import lombok.Value;
import org.json.JSONArray;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/26/16.
*/
@ -37,12 +35,8 @@ public class HighLowDiscrete extends DiscreteSpace {
@Override
public Object encode(Integer a) {
JSONArray jsonArray = new JSONArray();
for (int i = 0; i < size; i++) {
jsonArray.put(matrix.getDouble(i, 0));
}
jsonArray.put(a - 1, matrix.getDouble(a - 1, 1));
return jsonArray;
INDArray m = matrix.dup();
m.put(a - 1, 0, matrix.getDouble(a - 1, 1));
return m;
}
}

View File

@ -44,11 +44,6 @@
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>gym-java-client</artifactId>
<version>${dl4j.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
@ -111,27 +106,10 @@
<profiles>
<profile>
<id>nd4j-tests-cpu</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>nd4j-tests-cuda</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.2</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -57,6 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
final private NN current;
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
final private AsyncConfiguration a3cc;
private final IAsyncLearning learning;
@Getter
private AtomicInteger T = new AtomicInteger(0);
@Getter
@ -64,10 +65,11 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter
private boolean running = true;
public AsyncGlobal(NN initial, AsyncConfiguration a3cc) {
public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
this.current = initial;
target = (NN) initial.clone();
this.a3cc = a3cc;
this.learning = learning;
queue = new ConcurrentLinkedQueue<>();
}
@ -106,11 +108,14 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
}
/**
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
*/
public void terminate() {
running = false;
queue.clear();
if(running) {
running = false;
queue.clear();
learning.terminate();
}
}
}

View File

@ -37,7 +37,10 @@ import org.nd4j.linalg.factory.Nd4j;
*/
@Slf4j
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN> {
extends Learning<O, A, AS, NN>
implements IAsyncLearning {
private Thread monitorThread = null;
@Getter(AccessLevel.PROTECTED)
private final TrainingListenerList listeners = new TrainingListenerList();
@ -126,6 +129,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
protected void monitorTraining() {
try {
monitorThread = Thread.currentThread();
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
canContinue = listeners.notifyTrainingProgress(this);
if(!canContinue) {
@ -139,10 +143,25 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
} catch (InterruptedException e) {
log.error("Training interrupted.", e);
}
monitorThread = null;
}
protected void cleanupPostTraining() {
// Worker threads stops automatically when the global thread stops
getAsyncGlobal().terminate();
}
/**
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
*/
public void terminate() {
if(canContinue) {
canContinue = false;
Thread safeMonitorThread = monitorThread;
if(safeMonitorThread != null) {
safeMonitorThread.interrupt();
}
}
}
}

View File

@ -21,15 +21,19 @@ import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;
/**
@ -43,7 +47,7 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
@Slf4j
public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Thread implements StepCountable, IEpochTrainer {
@Getter
@ -54,26 +58,35 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
private int stepCounter = 0;
@Getter @Setter
private int epochCounter = 0;
@Getter
private MDP<O, A, AS> mdp;
@Getter @Setter
private IHistoryProcessor historyProcessor;
private boolean isEpochStarted = false;
private final LegacyMDPWrapper<O, A, AS> mdp;
private final TrainingListenerList listeners;
public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
this.mdp = mdp;
this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
this.listeners = listeners;
this.threadNumber = threadNumber;
this.deviceNum = deviceNum;
}
public MDP<O, A, AS> getMdp() {
return mdp.getWrappedMDP();
}
protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
return mdp;
}
public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
historyProcessor = new HistoryProcessor(conf);
setHistoryProcessor(new HistoryProcessor(conf));
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
this.historyProcessor = historyProcessor;
mdp.setHistoryProcessor(historyProcessor);
}
protected void postEpoch() {
@ -109,61 +122,63 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
*/
@Override
public void run() {
RunContext<O> context = new RunContext<>();
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
try {
RunContext context = new RunContext();
Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
log.info("ThreadNum-" + threadNumber + " Started!");
boolean canContinue = initWork(context);
if (canContinue) {
log.info("ThreadNum-" + threadNumber + " Started!");
while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
handleTraining(context);
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
canContinue = finishEpoch(context) && startNewEpoch(context);
if (!isEpochStarted) {
boolean canContinue = startNewEpoch(context);
if (!canContinue) {
break;
}
}
handleTraining(context);
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
boolean canContinue = finishEpoch(context);
if (!canContinue) {
break;
}
++epochCounter;
}
}
}
terminateWork();
finally {
terminateWork();
}
}
private void initNewEpoch(RunContext context) {
getCurrent().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(getMdp(), historyProcessor);
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
context.epochElapsedSteps = initMdp.getSteps();
}
private void handleTraining(RunContext<O> context) {
private void handleTraining(RunContext context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
SubEpochReturn<O> subEpochReturn = trainSubEpoch(context.obs, maxSteps);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
stepCounter += subEpochReturn.getSteps();
context.epochElapsedSteps += subEpochReturn.getSteps();
context.rewards += subEpochReturn.getReward();
context.score = subEpochReturn.getScore();
}
private boolean initWork(RunContext context) {
initNewEpoch(context);
preEpoch();
return listeners.notifyNewEpoch(this);
}
private boolean startNewEpoch(RunContext context) {
initNewEpoch(context);
epochCounter++;
getCurrent().reset();
Learning.InitMdp<Observation> initMdp = refacInitMdp();
context.obs = initMdp.getLastObs();
context.rewards = initMdp.getReward();
context.epochElapsedSteps = initMdp.getSteps();
isEpochStarted = true;
preEpoch();
return listeners.notifyNewEpoch(this);
}
private boolean finishEpoch(RunContext context) {
isEpochStarted = false;
postEpoch();
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
@ -173,8 +188,10 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
}
private void terminateWork() {
postEpoch();
getAsyncGlobal().terminate();
if(isEpochStarted) {
postEpoch();
}
}
protected abstract NN getCurrent();
@ -185,13 +202,47 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
protected abstract IPolicy<O, A> getPolicy(NN net);
protected abstract SubEpochReturn<O> trainSubEpoch(O obs, int nstep);
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
private Learning.InitMdp<Observation> refacInitMdp() {
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
IHistoryProcessor hp = getHistoryProcessor();
Observation observation = mdp.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdp.isDone()) {
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
StepReply<Observation> stepReply = mdp.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
}
return new Learning.InitMdp(step, observation, reward);
}
public void incrementStep() {
++stepCounter;
}
@AllArgsConstructor
@Value
public static class SubEpochReturn<O> {
public static class SubEpochReturn {
int steps;
O lastObs;
Observation lastObs;
double reward;
double score;
}
@ -206,8 +257,8 @@ public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace
double score;
}
private static class RunContext<O extends Encodable> {
private O obs;
private static class RunContext {
private Observation obs;
private double rewards;
private int epochElapsedSteps;
private double score;

View File

@ -25,9 +25,11 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
@ -40,7 +42,7 @@ import java.util.Stack;
* Async Learning specialized for the Discrete Domain
*
*/
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
extends AsyncThread<O, Integer, DiscreteSpace, NN> {
@Getter
@ -61,14 +63,14 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
* @param nstep the number of max nstep (step until t_max or state is terminal)
* @return subepoch training informations
*/
public SubEpochReturn<O> trainSubEpoch(O sObs, int nstep) {
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
synchronized (getAsyncGlobal()) {
current.copy(getAsyncGlobal().getCurrent());
}
Stack<MiniTrans<Integer>> rewards = new Stack<>();
O obs = sObs;
Observation obs = sObs;
IPolicy<O, Integer> policy = getPolicy(current);
Integer action;
@ -81,93 +83,53 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
int i = 0;
while (!getMdp().isDone() && i < nstep * skipFrame) {
INDArray input = Learning.getInput(getMdp(), obs);
INDArray hstack = null;
if (hp != null) {
hp.record(input);
}
//if step of training, just repeat lastAction
if (i % skipFrame != 0 && lastAction != null) {
action = lastAction;
} else {
hstack = processHistory(input);
action = policy.nextAction(hstack);
action = policy.nextAction(obs);
}
StepReply<O> stepReply = getMdp().step(action);
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
accuReward += stepReply.getReward() * getConf().getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
obs = stepReply.getObservation();
if (hstack == null) {
hstack = processHistory(input);
}
INDArray[] output = current.outputAll(hstack);
rewards.add(new MiniTrans(hstack, action, output, accuReward));
INDArray[] output = current.outputAll(obs.getData());
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
accuReward = 0;
}
obs = stepReply.getObservation();
reward += stepReply.getReward();
i++;
incrementStep();
lastAction = action;
}
//a bit of a trick usable because of how the stack is treated to init R
INDArray input = Learning.getInput(getMdp(), obs);
INDArray hstack = processHistory(input);
if (hp != null) {
hp.record(input);
}
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
if (getMdp().isDone() && i < nstep * skipFrame)
rewards.add(new MiniTrans(hstack, null, null, 0));
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
if (getConf().getTargetDqnUpdateFreq() == -1)
output = current.outputAll(hstack);
output = current.outputAll(obs.getData());
else synchronized (getAsyncGlobal()) {
output = getAsyncGlobal().getTarget().outputAll(hstack);
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
}
double maxQ = Nd4j.max(output[0]).getDouble(0);
rewards.add(new MiniTrans(hstack, null, output, maxQ));
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
}
getAsyncGlobal().enqueue(calcGradient(current, rewards), i);
return new SubEpochReturn<O>(i, obs, reward, current.getLatestScore());
}
protected INDArray processHistory(INDArray input) {
IHistoryProcessor hp = getHistoryProcessor();
INDArray[] history;
if (hp != null) {
hp.add(input);
history = hp.getHistory();
} else
history = new INDArray[] {input};
//concat the history into a single INDArray input
INDArray hstack = Transition.concat(history);
if (hp != null) {
hstack.muli(1.0 / hp.getScale());
}
if (getCurrent().isRecurrent()) {
//flatten everything for the RNN
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1));
} else {
//if input is not 2d, you have to append that the batch is 1 length high
if (hstack.shape().length > 2)
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
}
return hstack;
return new SubEpochReturn(i, obs, reward, current.getLatestScore());
}
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);

View File

@ -0,0 +1,21 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
public interface IAsyncLearning {
void terminate();
}

View File

@ -53,7 +53,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
Integer seed = conf.getSeed();
Random rnd = Nd4j.getRandom();

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -46,13 +47,13 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
@Getter
final protected A3CDiscrete.A3CConfiguration conf;
@Getter
final protected AsyncGlobal<IActorCritic> asyncGlobal;
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
@Getter
final protected int threadNumber;
final private Random rnd;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, AsyncGlobal<IActorCritic> asyncGlobal,
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);

View File

@ -46,7 +46,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
}
@Override

View File

@ -150,6 +150,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
}
private InitMdp<Observation> refacInitMdp() {
getQNetwork().reset();
getTargetQNetwork().reset();
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
IHistoryProcessor hp = getHistoryProcessor();

View File

@ -46,7 +46,7 @@ import java.util.ArrayList;
*
* DQN or Deep Q-Learning in the Discrete domain
*
* https://arxiv.org/abs/1312.5602
* http://arxiv.org/abs/1312.5602
*
*/
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {

View File

@ -23,7 +23,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -104,7 +103,7 @@ public class HardDeteministicToy implements MDP<HardToyState, Integer, DiscreteS
if (a == maxIndex(hardToyState.getValues()))
reward += 1;
hardToyState = states[hardToyState.getStep() + 1];
return new StepReply(hardToyState, reward, isDone(), new JSONObject("{}"));
return new StepReply(hardToyState, reward, isDone(), null);
}
public HardDeteministicToy newInstance() {

View File

@ -26,7 +26,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -80,7 +79,7 @@ public class SimpleToy implements MDP<SimpleToyState, Integer, DiscreteSpace> {
public StepReply<SimpleToyState> step(Integer a) {
double reward = (simpleToyState.getStep() % 2 == 0) ? 1 - a : a;
simpleToyState = new SimpleToyState(simpleToyState.getI() + 1, simpleToyState.getStep() + 1);
return new StepReply<>(simpleToyState, reward, isDone(), new JSONObject("{}"));
return new StepReply<>(simpleToyState, reward, isDone(), null);
}
public SimpleToy newInstance() {

View File

@ -29,7 +29,15 @@ public class Observation {
private final DataSet data;
public Observation(INDArray[] data) {
this(new org.nd4j.linalg.dataset.DataSet(Nd4j.concat(0, data), null));
this(data, false);
}
public Observation(INDArray[] data, boolean shouldReshape) {
INDArray features = Nd4j.concat(0, data);
if(shouldReshape) {
features = reshape(features);
}
this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
}
// FIXME: Remove -- only used in unit tests
@ -37,6 +45,15 @@ public class Observation {
this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
}
private INDArray reshape(INDArray source) {
long[] shape = source.shape();
long[] nshape = new long[shape.length + 1];
nshape[0] = 1;
System.arraycopy(shape, 0, nshape, 1, shape.length);
return source.reshape(nshape);
}
private Observation(DataSet data) {
this.data = data;
}

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -65,6 +66,11 @@ public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
return actorCritic;
}
@Override
public Integer nextAction(Observation obs) {
return nextAction(obs.getData());
}
public Integer nextAction(INDArray input) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
@ -43,6 +44,11 @@ public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
return dqn;
}
@Override
public Integer nextAction(Observation obs) {
return nextAction(obs.getData());
}
public Integer nextAction(INDArray input) {
INDArray output = dqn.output(input);

View File

@ -20,6 +20,7 @@ import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -44,6 +45,11 @@ public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {
return dqn;
}
@Override
public Integer nextAction(Observation obs) {
return nextAction(obs.getData());
}
public Integer nextAction(INDArray input) {
INDArray output = dqn.output(input);
return Learning.getMaxAction(output);

View File

@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -9,4 +10,5 @@ import org.nd4j.linalg.api.ndarray.INDArray;
public interface IPolicy<O, A> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
A nextAction(INDArray input);
A nextAction(Observation observation);
}

View File

@ -16,15 +16,21 @@
package org.deeplearning4j.rl4j.policy;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;
@ -39,7 +45,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
public abstract NeuralNet getNeuralNet();
public abstract A nextAction(INDArray input);
public abstract A nextAction(Observation obs);
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
return play(mdp, (IHistoryProcessor)null);
@ -51,66 +57,81 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
@Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
RefacStepCountable stepCountable = new RefacStepCountable();
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable);
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
getNeuralNet().reset();
Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
O obs = initMdp.getLastObs();
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
Observation obs = initMdp.getLastObs();
double reward = initMdp.getReward();
A lastAction = mdp.getActionSpace().noOp();
A lastAction = mdpWrapper.getActionSpace().noOp();
A action;
int step = initMdp.getSteps();
INDArray[] history = null;
stepCountable.setStepCounter(initMdp.getSteps());
INDArray input = Learning.getInput(mdp, obs);
while (!mdpWrapper.isDone()) {
while (!mdp.isDone()) {
if (step % skipFrame != 0) {
if (stepCountable.getStepCounter() % skipFrame != 0) {
action = lastAction;
} else {
if (history == null) {
if (isHistoryProcessor) {
hp.add(input);
history = hp.getHistory();
} else
history = new INDArray[] {input};
}
INDArray hstack = Transition.concat(history);
if (isHistoryProcessor) {
hstack.muli(1.0 / hp.getScale());
}
if (getNeuralNet().isRecurrent()) {
//flatten everything for the RNN
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1));
} else {
if (hstack.shape().length > 2)
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
}
action = nextAction(hstack);
action = nextAction(obs);
}
lastAction = action;
StepReply<O> stepReply = mdp.step(action);
StepReply<Observation> stepReply = mdpWrapper.step(action);
reward += stepReply.getReward();
input = Learning.getInput(mdp, stepReply.getObservation());
if (isHistoryProcessor) {
hp.record(input);
hp.add(input);
}
history = isHistoryProcessor ? hp.getHistory()
: new INDArray[] {Learning.getInput(mdp, stepReply.getObservation())};
step++;
obs = stepReply.getObservation();
stepCountable.increment();
}
return reward;
}
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
getNeuralNet().reset();
Observation observation = mdpWrapper.reset();
int step = 0;
double reward = 0;
boolean isHistoryProcessor = hp != null;
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
while (step < requiredFrame && !mdpWrapper.isDone()) {
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
StepReply<Observation> stepReply = mdpWrapper.step(action);
reward += stepReply.getReward();
observation = stepReply.getObservation();
step++;
}
return new Learning.InitMdp(step, observation, reward);
}
private class RefacStepCountable implements StepCountable {
@Getter
@Setter
private int stepCounter = 0;
public void increment() {
++stepCounter;
}
@Override
public int getStepCounter() {
return 0;
}
}
}

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer;
@ -88,19 +89,38 @@ public class DataManager implements IDataManager {
String json = new ObjectMapper().writeValueAsString(learning.getConfiguration());
writeEntry(new ByteArrayInputStream(json.getBytes()), zipfile);
ZipEntry dqn = new ZipEntry("dqn.bin");
zipfile.putNextEntry(dqn);
try {
ZipEntry dqn = new ZipEntry("dqn.bin");
zipfile.putNextEntry(dqn);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(bos);
}
bos.flush();
bos.close();
InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
writeEntry(inputStream, zipfile);
} catch (UnsupportedOperationException e) {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ByteArrayOutputStream bos2 = new ByteArrayOutputStream();
((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(bos, bos2);
bos.flush();
bos.close();
InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
ZipEntry value = new ZipEntry("value.bin");
zipfile.putNextEntry(value);
writeEntry(inputStream, zipfile);
bos2.flush();
bos2.close();
InputStream inputStream2 = new ByteArrayInputStream(bos2.toByteArray());
ZipEntry policy = new ZipEntry("policy.bin");
zipfile.putNextEntry(policy);
writeEntry(inputStream2, zipfile);
}
bos.flush();
bos.close();
InputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
writeEntry(inputStream, zipfile);
if (learning.getHistoryProcessor() != null) {
ZipEntry hpconf = new ZipEntry("hpconf.bin");
@ -268,7 +288,12 @@ public class DataManager implements IDataManager {
save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
if(learning instanceof NeuralNetFetchable) {
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
try {
((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
} catch (UnsupportedOperationException e) {
String path = getModelDir() + "/" + learning.getStepCounter();
((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model");
}
}
}

View File

@ -4,6 +4,7 @@ import lombok.Getter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
@ -12,21 +13,56 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
@Getter
private final MDP<O, A, AS> wrappedMDP;
@Getter
private final WrapperObservationSpace observationSpace;
private final ILearning learning;
private IHistoryProcessor historyProcessor;
private final StepCountable stepCountable;
private int skipFrame;
private int step = 0;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
this(wrappedMDP, learning, null, null);
}
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
this(wrappedMDP, null, historyProcessor, stepCountable);
}
private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
this.wrappedMDP = wrappedMDP;
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
this.learning = learning;
this.historyProcessor = historyProcessor;
this.stepCountable = stepCountable;
}
private IHistoryProcessor getHistoryProcessor() {
if(historyProcessor != null) {
return historyProcessor;
}
if (learning != null) {
return learning.getHistoryProcessor();
}
return null;
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
this.historyProcessor = historyProcessor;
}
private int getStep() {
if(stepCountable != null) {
return stepCountable.getStepCounter();
}
return learning.getStepCounter();
}
@Override
@ -38,13 +74,12 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
public Observation reset() {
INDArray rawObservation = getInput(wrappedMDP.reset());
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
IHistoryProcessor historyProcessor = getHistoryProcessor();
if(historyProcessor != null) {
historyProcessor.record(rawObservation.dup());
rawObservation.muli(1.0 / historyProcessor.getScale());
historyProcessor.record(rawObservation);
}
Observation observation = new Observation(new INDArray[] { rawObservation });
Observation observation = new Observation(new INDArray[] { rawObservation }, false);
if(historyProcessor != null) {
skipFrame = historyProcessor.getConf().getSkipFrame();
@ -55,14 +90,9 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
return observation;
}
@Override
public void close() {
wrappedMDP.close();
}
@Override
public StepReply<Observation> step(A a) {
IHistoryProcessor historyProcessor = learning.getHistoryProcessor();
IHistoryProcessor historyProcessor = getHistoryProcessor();
StepReply<O> rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation());
@ -71,11 +101,10 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
int requiredFrame = 0;
if(historyProcessor != null) {
historyProcessor.record(rawObservation.dup());
rawObservation.muli(1.0 / historyProcessor.getScale());
historyProcessor.record(rawObservation);
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
if ((learning.getStepCounter() % skipFrame == 0 && step >= requiredFrame)
if ((getStep() % skipFrame == 0 && step >= requiredFrame)
|| (step % skipFrame == 0 && step < requiredFrame )){
historyProcessor.add(rawObservation);
}
@ -83,15 +112,21 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
Observation observation;
if(historyProcessor != null && step >= requiredFrame) {
observation = new Observation(historyProcessor.getHistory());
observation = new Observation(historyProcessor.getHistory(), true);
observation.getData().muli(1.0 / historyProcessor.getScale());
}
else {
observation = new Observation(new INDArray[] { rawObservation });
observation = new Observation(new INDArray[] { rawObservation }, false);
}
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
}
@Override
public void close() {
wrappedMDP.close();
}
@Override
public boolean isDone() {
return wrappedMDP.isDone();
@ -103,7 +138,7 @@ public class LegacyMDPWrapper<O extends Encodable, A, AS extends ActionSpace<A>>
}
private INDArray getInput(O obs) {
INDArray arr = Nd4j.create(obs.toArray());
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
int[] shape = observationSpace.getShape();
if (shape.length == 1)
return arr.reshape(new long[] {1, arr.length()});

View File

@ -72,7 +72,7 @@ public class AsyncLearningTest {
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
public final MockTrainingListener listener = new MockTrainingListener();
public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);
public TestContext() {
sut.addListener(listener);

View File

@ -2,16 +2,17 @@ package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import static org.junit.Assert.assertEquals;
@ -21,37 +22,51 @@ public class AsyncThreadDiscreteTest {
@Test
public void refac_AsyncThreadDiscrete_trainSubEpoch() {
// Arrange
int numEpochs = 1;
MockNeuralNet nnMock = new MockNeuralNet();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs);
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 0, 5,0, 0, 0, 0);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
MockEncodable obs = new MockEncodable(123);
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(1)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(2)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(3)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(4)));
hpMock.add(Learning.getInput(mdpMock, new MockEncodable(5)));
// Act
AsyncThread.SubEpochReturn<MockEncodable> result = sut.trainSubEpoch(obs, 2);
sut.run();
// Assert
assertEquals(4, result.getSteps());
assertEquals(6.0, result.getReward(), 0.00001);
assertEquals(0.0, result.getScore(), 0.00001);
assertEquals(3.0, result.getLastObs().toArray()[0], 0.00001);
assertEquals(1, asyncGlobalMock.enqueueCallCount);
assertEquals(2, sut.trainSubEpochResults.size());
double[][] expectedLastObservations = new double[][] {
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
};
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
for(int i = 0; i < 2; ++i) {
AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i);
assertEquals(4, result.getSteps());
assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001);
assertEquals(0.0, result.getScore(), 0.00001);
double[] expectedLastObservation = expectedLastObservations[i];
assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]);
for(int j = 0; j < expectedLastObservation.length; ++j) {
assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001);
}
}
assertEquals(2, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor
assertEquals(10, hpMock.addCalls.size());
double[] expectedRecordValues = new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 };
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
for(int i = 0; i < expectedAddValues.length; ++i) {
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
}
double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
for(int i = 0; i < expectedRecordValues.length; ++i) {
assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
@ -59,49 +74,89 @@ public class AsyncThreadDiscreteTest {
// Policy
double[][] expectedPolicyInputs = new double[][] {
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
};
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
double[] expectedRow = expectedPolicyInputs[i];
INDArray input = policyMock.actionInputs.get(i);
assertEquals(expectedRow.length, input.shape()[0]);
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
// NeuralNetwork
assertEquals(1, nnMock.copyCallCount);
assertEquals(2, nnMock.copyCallCount);
double[][] expectedNNInputs = new double[][] {
new double[] { 2.0, 3.0, 4.0, 5.0, 123.0 },
new double[] { 3.0, 4.0, 5.0, 123.0, 0.0 },
new double[] { 4.0, 5.0, 123.0, 0.0, 1.0 },
new double[] { 5.0, 123.0, 0.0, 1.0, 2.0 },
new double[] { 123.0, 0.0, 1.0, 2.0, 3.0 },
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
};
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
for(int i = 0; i < expectedNNInputs.length; ++i) {
double[] expectedRow = expectedNNInputs[i];
INDArray input = nnMock.outputAllInputs.get(i);
assertEquals(expectedRow.length, input.shape()[0]);
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
int arrayIdx = 0;
double[][][] expectedMinitransObs = new double[][][] {
new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
},
new double[][] {
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
}
};
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 };
assertEquals(2, sut.rewards.size());
for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) {
Stack<MiniTrans<Integer>> miniTransStack = sut.rewards.get(rewardIdx);
for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) {
MiniTrans minitrans = miniTransStack.get(i);
// Observation
double[] expectedRow = expectedMinitransObs[rewardIdx][i];
INDArray realRewards = minitrans.getObs();
assertEquals(expectedRow.length, realRewards.shape()[1]);
for (int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001);
}
assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001);
assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001);
++arrayIdx;
}
}
}
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
private final IAsyncGlobal<MockNeuralNet> asyncGlobal;
private final MockAsyncGlobal asyncGlobal;
private final MockPolicy policy;
private final MockAsyncConfiguration config;
public TestAsyncThreadDiscrete(IAsyncGlobal<MockNeuralNet> asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
public final List<Stack<MiniTrans<Integer>>> rewards = new ArrayList<Stack<MiniTrans<Integer>>>();
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
MockAsyncConfiguration config, IHistoryProcessor hp) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
@ -113,6 +168,7 @@ public class AsyncThreadDiscreteTest {
@Override
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
this.rewards.add(rewards);
return new Gradient[0];
}
@ -130,5 +186,13 @@ public class AsyncThreadDiscreteTest {
protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
return policy;
}
@Override
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
asyncGlobal.increaseCurrentLoop();
SubEpochReturn result = super.trainSubEpoch(sObs, nstep);
trainSubEpochResults.add(result);
return result;
}
}
}

View File

@ -6,11 +6,14 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
@ -23,107 +26,100 @@ public class AsyncThreadTest {
@Test
public void when_newEpochStarted_expect_neuralNetworkReset() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingOnNewEpochCallCount(5);
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
// Act
context.sut.run();
// Assert
assertEquals(6, context.neuralNet.resetCallCount);
assertEquals(numberOfEpochs, context.neuralNet.resetCallCount);
}
@Test
public void when_onNewEpochReturnsStop_expect_threadStopped() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingOnNewEpochCallCount(1);
int stopAfterNumCalls = 1;
TestContext context = new TestContext(100000);
context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls);
// Act
context.sut.run();
// Assert
assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(1, context.listener.onEpochTrainingResultCallCount);
assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted
assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount);
}
@Test
public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
// Arrange
TestContext context = new TestContext();
context.listener.setRemainingOnEpochTrainingResult(1);
int stopAfterNumCalls = 1;
TestContext context = new TestContext(100000);
context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls);
// Act
context.sut.run();
// Assert
assertEquals(2, context.listener.onNewEpochCallCount);
assertEquals(2, context.listener.onEpochTrainingResultCallCount);
assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted
assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop
}
@Test
public void when_run_expect_preAndPostEpochCalled() {
// Arrange
TestContext context = new TestContext();
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
// Act
context.sut.run();
// Assert
assertEquals(6, context.sut.preEpochCallCount);
assertEquals(6, context.sut.postEpochCallCount);
assertEquals(numberOfEpochs, context.sut.preEpochCallCount);
assertEquals(numberOfEpochs, context.sut.postEpochCallCount);
}
@Test
public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
// Arrange
TestContext context = new TestContext();
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
// Act
context.sut.run();
// Assert
assertEquals(5, context.listener.statEntries.size());
assertEquals(numberOfEpochs, context.listener.statEntries.size());
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
for(int i = 0; i < 5; ++i) {
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
+ 1.0; // Reward from trainSubEpoch()
for(int i = 0; i < numberOfEpochs; ++i) {
IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
assertEquals(i, statEntry.getEpochCounter());
assertEquals(38.0, statEntry.getReward(), 0.0001);
assertEquals(expectedReward, statEntry.getReward(), 0.0001);
}
}
@Test
public void when_run_expect_NeuralNetIsResetAtInitAndEveryEpoch() {
// Arrange
TestContext context = new TestContext();
// Act
context.sut.run();
// Assert
assertEquals(6, context.neuralNet.resetCallCount);
}
@Test
public void when_run_expect_trainSubEpochCalled() {
// Arrange
TestContext context = new TestContext();
int numberOfEpochs = 5;
TestContext context = new TestContext(numberOfEpochs);
// Act
context.sut.run();
// Assert
assertEquals(10, context.sut.trainSubEpochParams.size());
for(int i = 0; i < 10; ++i) {
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
if(i % 2 == 0) {
assertEquals(2, params.nstep);
assertEquals(8.0, params.obs.toArray()[0], 0.00001);
}
else {
assertEquals(1, params.nstep);
assertNull(params.obs);
assertEquals(2, params.nstep);
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
for(int j = 0; j < expectedObservation.length; ++j){
assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001);
}
}
}
@ -136,30 +132,30 @@ public class AsyncThreadTest {
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
private final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
public TestContext() {
asyncGlobal.setMaxLoops(10);
public TestContext(int numEpochs) {
asyncGlobal.setMaxLoops(numEpochs);
listeners.add(listener);
sut.setHistoryProcessor(historyProcessor);
}
}
public static class MockAsyncThread extends AsyncThread {
public static class MockAsyncThread extends AsyncThread<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
public int preEpochCallCount = 0;
public int postEpochCallCount = 0;
private final IAsyncGlobal asyncGlobal;
private final MockAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf;
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
@ -180,7 +176,7 @@ public class AsyncThreadTest {
}
@Override
protected NeuralNet getCurrent() {
protected MockNeuralNet getCurrent() {
return neuralNet;
}
@ -195,20 +191,22 @@ public class AsyncThreadTest {
}
@Override
protected Policy getPolicy(NeuralNet net) {
protected Policy getPolicy(MockNeuralNet net) {
return null;
}
@Override
protected SubEpochReturn trainSubEpoch(Encodable obs, int nstep) {
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
asyncGlobal.increaseCurrentLoop();
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
return new SubEpochReturn(1, null, 1.0, 1.0);
setStepCounter(getStepCounter() + nstep);
return new SubEpochReturn(nstep, null, 1.0, 1.0);
}
@AllArgsConstructor
@Getter
public static class TrainSubEpochParams {
Encodable obs;
Observation obs;
int nstep;
}
}

View File

@ -0,0 +1,181 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class A3CThreadDiscreteTest {
@Test
public void refac_calcGradient() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0);
MockActorCritic actorCriticMock = new MockActorCritic();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
A3CThreadDiscrete sut = new A3CThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, 0, null, 0);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hpMock);
double[][] minitransObs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
};
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
for(int i = 0; i < 3; ++i) {
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
INDArray[] output = new INDArray[] {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(actorCriticMock, minitransList);
// Assert
assertEquals(1, actorCriticMock.gradientParams.size());
INDArray input = actorCriticMock.gradientParams.get(0).getFirst();
INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond();
assertEquals(minitransObs.length, input.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = minitransObs[i];
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
}
}
double latestReward = (gamma * 4.0) + 3.0;
double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward };
for(int i = 0; i < expectedLabels0.length; ++i) {
assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001);
}
double[][] expectedLabels1 = new double[][] {
new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 },
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
};
assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape());
for(int i = 0; i < expectedLabels1.length; ++i) {
double[] expectedRow = expectedLabels1[i];
assertEquals(expectedRow.length, labels[1].shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001);
}
}
}
public class MockActorCritic implements IActorCritic {
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public IActorCritic clone() {
return this;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IActorCritic from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
return new Gradient[0];
}
@Override
public void applyGradient(Gradient[] gradient, int batchSize) {
}
@Override
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
}
@Override
public void save(String pathValue, String pathPolicy) throws IOException {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
}

View File

@ -0,0 +1,81 @@
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
import static org.junit.Assert.assertEquals;
public class AsyncNStepQLearningThreadDiscreteTest {
@Test
public void refac_calcGradient() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0);
MockDQN dqnMock = new MockDQN();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, null, 0, 0);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hpMock);
double[][] minitransObs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
};
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
for(int i = 0; i < 3; ++i) {
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
INDArray[] output = new INDArray[] {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(dqnMock, minitransList);
// Assert
assertEquals(1, dqnMock.gradientParams.size());
INDArray input = dqnMock.gradientParams.get(0).getFirst();
INDArray labels = dqnMock.gradientParams.get(0).getSecond();
assertEquals(minitransObs.length, input.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = minitransObs[i];
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
}
}
double latestReward = (gamma * 4.0) + 3.0;
double[][] expectedLabels = new double[][] {
new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
};
assertEquals(minitransObs.length, labels.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = expectedLabels[i];
assertEquals(expectedRow.length, labels.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
}
}
}
}

View File

@ -63,7 +63,7 @@ public class QLearningDiscreteTest {
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
assertEquals(expectedAdds.length, hp.addCalls.size());
for(int i = 0; i < expectedAdds.length; ++i) {
assertEquals(expectedAdds[i], 255.0 * hp.addCalls.get(i).getDouble(0), 0.0001);
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount);
@ -92,8 +92,8 @@ public class QLearningDiscreteTest {
for(int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
assertEquals(5, outputParam.shape()[0]);
assertEquals(1, outputParam.shape()[1]);
assertEquals(5, outputParam.shape()[1]);
assertEquals(1, outputParam.shape()[2]);
double[] expectedRow = expectedDQNOutput[i];
for(int j = 0; j < expectedRow.length; ++j) {
@ -124,13 +124,15 @@ public class QLearningDiscreteTest {
assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(j, 0), 0.0001);
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
}
}
// trainEpoch result
assertEquals(16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
}
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {

View File

@ -28,6 +28,7 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscret
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
@ -229,6 +230,11 @@ public class PolicyTest {
return neuralNet;
}
@Override
public Integer nextAction(Observation obs) {
return nextAction(obs.getData());
}
@Override
public Integer nextAction(INDArray input) {
return (int)input.getDouble(0);

View File

@ -8,9 +8,10 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public class MockAsyncGlobal implements IAsyncGlobal {
public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
private final NeuralNet current;
@Getter
private final NN current;
public boolean hasBeenStarted = false;
public boolean hasBeenTerminated = false;
@ -27,7 +28,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
this(null);
}
public MockAsyncGlobal(NeuralNet current) {
public MockAsyncGlobal(NN current) {
maxLoops = Integer.MAX_VALUE;
numLoopsStopRunning = Integer.MAX_VALUE;
this.current = current;
@ -45,7 +46,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
@Override
public boolean isTrainingComplete() {
return ++currentLoop > maxLoops;
return currentLoop >= maxLoops;
}
@Override
@ -59,12 +60,7 @@ public class MockAsyncGlobal implements IAsyncGlobal {
}
@Override
public NeuralNet getCurrent() {
return current;
}
@Override
public NeuralNet getTarget() {
public NN getTarget() {
return current;
}
@ -72,4 +68,8 @@ public class MockAsyncGlobal implements IAsyncGlobal {
public void enqueue(Gradient[] gradient, Integer nstep) {
++enqueueCallCount;
}
public void increaseCurrentLoop() {
++currentLoop;
}
}

View File

@ -15,8 +15,10 @@ import java.util.List;
public class MockDQN implements IDQN {
public boolean hasBeenReset = false;
public final List<INDArray> outputParams = new ArrayList<>();
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
@ -30,7 +32,7 @@ public class MockDQN implements IDQN {
@Override
public void reset() {
hasBeenReset = true;
}
@Override
@ -61,7 +63,10 @@ public class MockDQN implements IDQN {
@Override
public IDQN clone() {
return null;
MockDQN clone = new MockDQN();
clone.hasBeenReset = hasBeenReset;
return clone;
}
@Override
@ -76,6 +81,7 @@ public class MockDQN implements IDQN {
@Override
public Gradient[] gradient(INDArray input, INDArray label) {
gradientParams.add(new Pair<INDArray, INDArray>(input, label));
return new Gradient[0];
}

View File

@ -35,7 +35,7 @@ public class MockNeuralNet implements NeuralNet {
@Override
public INDArray[] outputAll(INDArray batch) {
outputAllInputs.add(batch);
return new INDArray[] { Nd4j.create(new double[] { 1.0 }) };
return new INDArray[] { Nd4j.create(new double[] { outputAllInputs.size() }) };
}
@Override

View File

@ -2,6 +2,7 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -23,6 +24,11 @@ public class MockPolicy implements IPolicy<MockEncodable, Integer> {
@Override
public Integer nextAction(INDArray input) {
actionInputs.add(input);
return null;
return input.getInt(0, 0, 0);
}
@Override
public Integer nextAction(Observation observation) {
return nextAction(observation.getData());
}
}

View File

@ -11,6 +11,7 @@ import java.util.List;
public class MockTrainingListener implements TrainingListener {
private final MockAsyncGlobal asyncGlobal;
public int onTrainingStartCallCount = 0;
public int onTrainingEndCallCount = 0;
public int onNewEpochCallCount = 0;
@ -28,6 +29,14 @@ public class MockTrainingListener implements TrainingListener {
public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
public MockTrainingListener() {
this(null);
}
public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
this.asyncGlobal = asyncGlobal;
}
@Override
public ListenerResponse onTrainingStart() {
@ -55,6 +64,9 @@ public class MockTrainingListener implements TrainingListener {
public ListenerResponse onTrainingProgress(ILearning learning) {
++onTrainingProgressCallCount;
--remainingonTrainingProgressCallCount;
if(asyncGlobal != null) {
asyncGlobal.increaseCurrentLoop();
}
return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
}

View File

@ -39,4 +39,13 @@
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -39,9 +39,18 @@
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>gym-java-client</artifactId>
<version>${project.version}</version>
<groupId>org.bytedeco</groupId>
<artifactId>gym-platform</artifactId>
<version>${gym.version}-${javacpp-presets.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,39 +17,111 @@
package org.deeplearning4j.rl4j.mdp.gym;
import org.deeplearning4j.gym.Client;
import org.deeplearning4j.gym.ClientFactory;
import java.io.IOException;
import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.HighLowDiscrete;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.bytedeco.cpython.*;
import org.bytedeco.numpy.*;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.*;
/**
* An MDP for OpenAI Gym: https://gym.openai.com/
*
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16.
*
* Wrapper over the client of gym-java-client
*
* @author saudet
*/
@Slf4j
public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
final public static String GYM_MONITOR_DIR = "/tmp/gym-dqn";
public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
final private Client<O, A, AS> client;
private static void checkPythonError() {
if (PyErr_Occurred() != null) {
PyErr_Print();
throw new RuntimeException("Python error occurred");
}
}
private static Pointer program;
private static PyObject globals;
static {
try {
Py_SetPath(org.bytedeco.gym.presets.gym.cachePackages());
program = Py_DecodeLocale(GymEnv.class.getSimpleName(), null);
Py_SetProgramName(program);
Py_Initialize();
PyEval_InitThreads();
PySys_SetArgvEx(1, program, 0);
if (_import_array() < 0) {
PyErr_Print();
throw new RuntimeException("numpy.core.multiarray failed to import");
}
globals = PyModule_GetDict(PyImport_AddModule("__main__"));
PyEval_SaveThread(); // just to release the GIL
} catch (IOException e) {
PyMem_RawFree(program);
throw new RuntimeException(e);
}
}
private PyObject locals;
final protected DiscreteSpace actionSpace;
final protected ObservationSpace<O> observationSpace;
@Getter
final private String envId;
@Getter
final private boolean render;
@Getter
final private boolean monitor;
private ActionTransformer actionTransformer = null;
private boolean done = false;
public GymEnv(String envId, boolean render, boolean monitor) {
this.client = ClientFactory.build(envId, render);
this.envId = envId;
this.render = render;
this.monitor = monitor;
if (monitor)
client.monitorStart(GYM_MONITOR_DIR, true, false);
int gstate = PyGILState_Ensure();
try {
locals = PyDict_New();
Py_DecRef(PyRun_StringFlags("import gym; env = gym.make('" + envId + "')", Py_single_input, globals, locals, null));
checkPythonError();
if (monitor) {
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
checkPythonError();
}
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
for (int i = 0; i < shape.length; i++) {
shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
}
observationSpace = (ObservationSpace<O>) new ArrayObservationSpace<Box>(shape);
Py_DecRef(shapeTuple);
PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
actionSpace = new DiscreteSpace((int)PyLong_AsLong(n));
Py_DecRef(n);
checkPythonError();
} finally {
PyGILState_Release(gstate);
}
}
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
@ -56,43 +129,87 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
}
@Override
public ObservationSpace<O> getObservationSpace() {
return client.getObservationSpace();
return observationSpace;
}
@Override
public AS getActionSpace() {
if (actionTransformer == null)
return (AS) client.getActionSpace();
return (AS) actionSpace;
else
return (AS) actionTransformer;
}
@Override
public StepReply<O> step(A action) {
StepReply<O> stepRep = client.step(action);
done = stepRep.isDone();
return stepRep;
int gstate = PyGILState_Ensure();
try {
if (render) {
Py_DecRef(PyRun_StringFlags("env.render()", Py_single_input, globals, locals, null));
checkPythonError();
}
Py_DecRef(PyRun_StringFlags("state, reward, done, info = env.step(" + (Integer)action +")", Py_single_input, globals, locals, null));
checkPythonError();
PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state"));
DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state));
SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state));
double reward = PyFloat_AsDouble(PyDict_GetItemString(locals, "reward"));
done = PyLong_AsLong(PyDict_GetItemString(locals, "done")) != 0;
checkPythonError();
double[] data = new double[(int)stateData.capacity()];
stateData.get(data);
return new StepReply(new Box(data), reward, done, null);
} finally {
PyGILState_Release(gstate);
}
}
@Override
public boolean isDone() {
return done;
}
@Override
public O reset() {
done = false;
return client.reset();
}
public void upload(String apiKey) {
client.upload(GYM_MONITOR_DIR, apiKey);
int gstate = PyGILState_Ensure();
try {
Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
checkPythonError();
PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state"));
DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state));
SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state));
checkPythonError();
done = false;
double[] data = new double[(int)stateData.capacity()];
stateData.get(data);
return (O) new Box(data);
} finally {
PyGILState_Release(gstate);
}
}
@Override
public void close() {
if (monitor)
client.monitorClose();
int gstate = PyGILState_Ensure();
try {
Py_DecRef(PyRun_StringFlags("env.close()", Py_single_input, globals, locals, null));
checkPythonError();
Py_DecRef(locals);
} finally {
PyGILState_Release(gstate);
}
}
@Override
public GymEnv<O, A, AS> newInstance() {
return new GymEnv<O, A, AS>(envId, render, monitor);
}

View File

@ -0,0 +1,48 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.mdp.gym;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.junit.Test;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
/**
*
* @author saudet
*/
public class GymEnvTest {
@Test
public void testCartpole() {
GymEnv mdp = new GymEnv("CartPole-v0", false, false);
assertArrayEquals(new int[] {4}, ((ArrayObservationSpace)mdp.getObservationSpace()).getShape());
assertEquals(2, ((DiscreteSpace)mdp.getActionSpace()).getSize());
assertEquals(false, mdp.isDone());
Box o = (Box)mdp.reset();
StepReply r = mdp.step(0);
assertEquals(4, o.toArray().length);
assertEquals(4, ((Box)r.getObservation()).toArray().length);
assertNotEquals(null, mdp.newInstance());
mdp.close();
}
}

View File

@ -33,6 +33,11 @@
</properties>
<dependencies>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20190722</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>rl4j-api</artifactId>
@ -44,4 +49,13 @@
<version>0.30.0</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>