From 9edbefdc67f6174564fd667dfe20f88c2cdfd178 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Mon, 20 Jan 2020 17:13:57 +0900 Subject: [PATCH] RL4J: Replace gym-java-client with JavaCPP (#8595) * RL4J: Replace gym-java-client with JavaCPP Signed-off-by: Samuel Audet --- README.md | 1 - gym-java-client/.gitignore | 4 - gym-java-client/LICENSE.txt | 201 ---------- gym-java-client/README.md | 108 ------ gym-java-client/contrib/formatter.xml | 354 ------------------ gym-java-client/pom.xml | 337 ----------------- .../java/org/deeplearning4j/gym/Client.java | 200 ---------- .../org/deeplearning4j/gym/ClientFactory.java | 90 ----- .../org/deeplearning4j/gym/ClientUtils.java | 75 ---- .../rl4j/space/GymObservationSpace.java | 82 ---- .../deeplearning4j/gym/test/ClientTest.java | 140 ------- .../gym/test/JSONObjectMatcher.java | 46 --- pom.xml | 22 +- rl4j/README.md | 10 +- rl4j/pom.xml | 50 +++ rl4j/rl4j-ale/pom.xml | 9 + rl4j/rl4j-api/pom.xml | 14 +- .../org/deeplearning4j/gym/StepReply.java | 3 +- .../rl4j/space/ActionSpace.java | 0 .../rl4j/space/ArrayObservationSpace.java | 0 .../org/deeplearning4j/rl4j/space/Box.java | 12 +- .../rl4j/space/DiscreteSpace.java | 0 .../deeplearning4j/rl4j/space/Encodable.java | 0 .../rl4j/space/HighLowDiscrete.java | 12 +- .../rl4j/space/ObservationSpace.java | 0 rl4j/rl4j-core/pom.xml | 26 +- .../rl4j/mdp/toy/HardDeteministicToy.java | 3 +- .../rl4j/mdp/toy/SimpleToy.java | 3 +- .../deeplearning4j/rl4j/util/DataManager.java | 49 ++- .../rl4j/util/LegacyMDPWrapper.java | 5 +- rl4j/rl4j-doom/pom.xml | 9 + rl4j/rl4j-gym/pom.xml | 15 +- .../deeplearning4j/rl4j/mdp/gym/GymEnv.java | 169 +++++++-- .../rl4j/mdp/gym/GymEnvTest.java | 48 +++ rl4j/rl4j-malmo/pom.xml | 14 + 35 files changed, 358 insertions(+), 1753 deletions(-) delete mode 100644 gym-java-client/.gitignore delete mode 100644 gym-java-client/LICENSE.txt delete mode 100644 gym-java-client/README.md delete mode 100644 gym-java-client/contrib/formatter.xml delete mode 100644 gym-java-client/pom.xml delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java delete mode 100644 gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java delete mode 100644 gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java delete mode 100644 gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/gym/StepReply.java (95%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/Box.java (83%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java (100%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java (82%) rename {gym-java-client => rl4j/rl4j-api}/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java (100%) create mode 100644 rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java diff --git a/README.md b/README.md index 2b69f7981..6a3d206c7 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ Welcome to the new monorepo of Deeplearning4j that contains the source code for * https://github.com/eclipse/deeplearning4j/tree/master/datavec * https://github.com/eclipse/deeplearning4j/tree/master/arbiter * https://github.com/eclipse/deeplearning4j/tree/master/nd4s - * https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client * https://github.com/eclipse/deeplearning4j/tree/master/rl4j * https://github.com/eclipse/deeplearning4j/tree/master/scalnet * https://github.com/eclipse/deeplearning4j/tree/master/pydl4j diff --git a/gym-java-client/.gitignore b/gym-java-client/.gitignore deleted file mode 100644 index 089ca5bb0..000000000 --- a/gym-java-client/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target/ -.idea/ -*.iml -*-git.properties diff --git a/gym-java-client/LICENSE.txt b/gym-java-client/LICENSE.txt deleted file mode 100644 index 5c304d1a4..000000000 --- a/gym-java-client/LICENSE.txt +++ /dev/null @@ -1,201 +0,0 @@ -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "{}" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright {yyyy} {name of copyright owner} - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/gym-java-client/README.md b/gym-java-client/README.md deleted file mode 100644 index 3ae9aa408..000000000 --- a/gym-java-client/README.md +++ /dev/null @@ -1,108 +0,0 @@ -# gym-java-client - -A java http client for [gym-http-api](https://github.com/openai/gym-http-api). - -Note: If you are encountering errors as reported in [issue #13](https://github.com/deeplearning4j/gym-java-client/issues/13), please execute the following command before launching `python gym_http_server.py`: - -```bash -$ sudo sysctl -w net.ipv4.tcp_tw_recycle=1 -``` - -# Quickstart - -To create a new Client, use the ClientFactory. If the url is not localhost:5000, provide it as a second argument - -```java -Client client = ClientFactory.build("CartPole-v0"); -``` - -"CartPole-v0" is the name of the gym environment. - -The type parameters of a client are the Observation type, the Action type, the Observation Space type and the ActionSpace type. - -It is a bit cumbersome to both declare an ActionSpace and an Action since an ActionSpace knows what type is an Action but unfortunately java does't support type member and path dependant types. - -Here we use Box and BoxSpace for the environment and Integer and Discrete Space because it is how [CartPole-v0](https://gym.openai.com/envs/CartPole-v0) is specified. - -The methods nomenclature follows closely the api interface of [gym-http-api](https://github.com/openai/gym-http-api#api-specification), O is Observation an A is Action: - -```java -//Static methods - -/** - * @param url url of the server - * @return set of all environments running on the server at the url - */ -public static Set listAll(String url); - -/** - * Shutdown the server at the url - * - * @param url url of the server - */ -public static void serverShutdown(String url); - - - -//Methods accessible from a Client -/** - * @return set of all environments running on the same server than this client - */ -public Set listAll(); - -/** - * Step the environment by one action - * - * @param action action to step the environment with - * @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information. - */ -public StepReply step(A action); -/** - * Reset the state of the environment and return an initial observation. - * - * @return initial observation - */ -public O reset(); - -/** - * Start monitoring. - * - * @param directory path to directory in which store the monitoring file - * @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.") - * @param resume retain the training data already in this directory, which will be merged with our new data - */ -public void monitorStart(String directory, boolean force, boolean resume); - -/** - * Flush all monitor data to disk - */ -public void monitorClose(); - -/** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - * @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running. - **/ -public void upload(String trainingDir, String apiKey, String algorithmId); - -/** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - */ -public void upload(String trainingDir, String apiKey); - - -/** - * Shutdown the server at the same url than this client - */ -public void serverShutdown() - -``` - -## TODO - -* Add all ObservationSpace and ActionSpace when they will be available. diff --git a/gym-java-client/contrib/formatter.xml b/gym-java-client/contrib/formatter.xml deleted file mode 100644 index 1d0adbbe1..000000000 --- a/gym-java-client/contrib/formatter.xml +++ /dev/null @@ -1,354 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/gym-java-client/pom.xml b/gym-java-client/pom.xml deleted file mode 100644 index fc82d01b2..000000000 --- a/gym-java-client/pom.xml +++ /dev/null @@ -1,337 +0,0 @@ - - - - - - - - org.deeplearning4j - deeplearning4j - 1.0.0-SNAPSHOT - - - 4.0.0 - - org.deeplearning4j - gym-java-client - - gym-java-client - A Java client for Open AI's Reinforcement Learning Gym - - - - Apache License, Version 2.0 - http://www.apache.org/licenses/LICENSE-2.0.txt - repo - - - - - - rubenfiszel - Ruben Fiszel - ruben.fiszel@epfl.ch - - - - - nd4j-native - - - - - org.nd4j - ${nd4j.backend} - ${nd4j.version} - - - commons-codec - commons-codec - ${commons-codec.version} - - - org.apache.httpcomponents - httpclient - ${httpclient.version} - - - org.apache.httpcomponents - httpcore - ${httpcore.version} - - - org.apache.httpcomponents - httpmime - ${httpmime.version} - - - com.mashape.unirest - unirest-java - ${unirest.version} - - - org.objenesis - objenesis - ${objenesis.version} - - - org.mockito - mockito-core - ${mockito.version} - test - - - org.ow2.asm - asm - ${asm.version} - - - cglib - cglib - 3.1 - test - - - junit - junit - ${junit.version} - test - - - org.powermock - powermock-api-mockito2 - 1.7.3 - test - - - org.powermock - powermock-module-junit4 - 1.7.3 - test - - - org.slf4j - slf4j-api - ${slf4j.version} - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - ch.qos.logback - logback-core - ${logback.version} - test - - - org.projectlombok - lombok - ${lombok.version} - provided - - - - - - - maven-source-plugin - ${maven-source-plugin.version} - - - attach-sources - - jar - - - - - - maven-surefire-plugin - ${maven-surefire-plugin.version} - - - true - false - - - - maven-javadoc-plugin - ${maven-javadoc-plugin.version} - - -Xdoclint:none - - - - attach-javadocs - - jar - - - - - - com.lewisd - lint-maven-plugin - 0.0.11 - - true - - DuplicateDep - RedundantDepVersion - RedundantPluginVersion - - - - - - pom-lint - validate - - check - - - - - - net.revelc.code.formatter - formatter-maven-plugin - 2.0.0 - - ${session.executionRootDirectory}/contrib/formatter.xml - - - - - pl.project13.maven - git-commit-id-plugin - ${maven-git-commit-plugin.version} - - - - revision - - generate-resources - - - - true - - ${project.basedir}/target/generated-sources/src/main/resources/ai/skymind/${project.groupId}-${project.artifactId}-git.properties - - - true - - - - - - org.codehaus.mojo - build-helper-maven-plugin - ${maven-build-helper-plugin.version} - - - add-resource - generate-resources - - add-resource - - - - - - ${project.basedir}/target/generated-sources/src/main/resources - - - - - - - - - - - - org.eclipse.m2e - lifecycle-mapping - 1.0.0 - - - - - - com.lewisd - lint-maven-plugin - [0.0.11,) - - check - - - - - - - - - - - - - - - - - - maven-surefire-report-plugin - 2.19.1 - - - - org.codehaus.mojo - cobertura-maven-plugin - 2.7 - - - - - - - nd4j-backend - - - libnd4j.cuda - - - - nd4j-cuda-${libnd4j.cuda} - - - - diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java deleted file mode 100644 index c3f221c3f..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/Client.java +++ /dev/null @@ -1,200 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym; - - -import com.mashape.unirest.http.JsonNode; -import lombok.Value; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.space.GymObservationSpace; -import org.deeplearning4j.rl4j.space.ActionSpace; -import org.json.JSONObject; - -import java.util.Set; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16. - * - * A client represent an active connection to a specific instance of an environment on a rl4j-http-api server. - * for API specification - * - * @param Observation type - * @param Action type - * @param Action Space type - * @see https://github.com/openai/gym-http-api#api-specification - */ -@Slf4j -@Value -public class Client> { - - - public static String V1_ROOT = "/v1"; - public static String ENVS_ROOT = V1_ROOT + "/envs/"; - - public static String MONITOR_START = "/monitor/start/"; - public static String MONITOR_CLOSE = "/monitor/close/"; - public static String CLOSE = "/close/"; - public static String RESET = "/reset/"; - public static String SHUTDOWN = "/shutdown/"; - public static String UPLOAD = "/upload/"; - public static String STEP = "/step/"; - public static String OBSERVATION_SPACE = "/observation_space/"; - public static String ACTION_SPACE = "/action_space/"; - - - String url; - String envId; - String instanceId; - GymObservationSpace observationSpace; - AS actionSpace; - boolean render; - - - /** - * @param url url of the server - * @return set of all environments running on the server at the url - */ - public static Set listAll(String url) { - JSONObject reply = ClientUtils.get(url + ENVS_ROOT); - return reply.getJSONObject("envs").keySet(); - } - - /** - * Shutdown the server at the url - * - * @param url url of the server - */ - public static void serverShutdown(String url) { - ClientUtils.post(url + ENVS_ROOT + SHUTDOWN, new JSONObject()); - } - - /** - * @return set of all environments running on the same server than this client - */ - public Set listAll() { - return listAll(url); - } - - /** - * Step the environment by one action - * - * @param action action to step the environment with - * @return the StepReply containing the next observation, the reward, if it is a terminal state and optional information. - */ - public StepReply step(A action) { - JSONObject body = new JSONObject().put("action", getActionSpace().encode(action)).put("render", render); - - JSONObject reply = ClientUtils.post(url + ENVS_ROOT + instanceId + STEP, body).getObject(); - - O observation = observationSpace.getValue(reply, "observation"); - double reward = reply.getDouble("reward"); - boolean done = reply.getBoolean("done"); - JSONObject info = reply.getJSONObject("info"); - - return new StepReply(observation, reward, done, info); - } - - /** - * Reset the state of the environment and return an initial observation. - * - * @return initial observation - */ - public O reset() { - JsonNode resetRep = ClientUtils.post(url + ENVS_ROOT + instanceId + RESET, new JSONObject()); - return observationSpace.getValue(resetRep.getObject(), "observation"); - } - - /* - Present in the doc but not working currently server-side - public void monitorStart(String directory) { - - JSONObject json = new JSONObject() - .put("directory", directory); - - monitorStartPost(json); - } - */ - - /** - * Start monitoring. - * - * @param directory path to directory in which store the monitoring file - * @param force clear out existing training data from this directory (by deleting every file prefixed with "openaigym.") - * @param resume retain the training data already in this directory, which will be merged with our new data - */ - public void monitorStart(String directory, boolean force, boolean resume) { - JSONObject json = new JSONObject().put("directory", directory).put("force", force).put("resume", resume); - - monitorStartPost(json); - } - - private void monitorStartPost(JSONObject json) { - ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_START, json); - } - - /** - * Flush all monitor data to disk - */ - public void monitorClose() { - ClientUtils.post(url + ENVS_ROOT + instanceId + MONITOR_CLOSE, new JSONObject()); - } - - /** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - * @param algorithmId an arbitrary string indicating the paricular version of the algorithm (including choices of parameters) you are running. - **/ - public void upload(String trainingDir, String apiKey, String algorithmId) { - JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey).put("algorithm_id", - algorithmId); - - uploadPost(json); - } - - /** - * Upload monitoring data to OpenAI servers. - * - * @param trainingDir directory that contains the monitoring data - * @param apiKey personal OpenAI API key - */ - public void upload(String trainingDir, String apiKey) { - JSONObject json = new JSONObject().put("training_dir", trainingDir).put("api_key", apiKey); - - uploadPost(json); - } - - private void uploadPost(JSONObject json) { - try { - ClientUtils.post(url + V1_ROOT + UPLOAD, json); - } catch (RuntimeException e) { - log.error("Impossible to upload: Wrong API key?"); - } - } - - /** - * Shutdown the server at the same url than this client - */ - public void serverShutdown() { - serverShutdown(url); - } - - public ActionSpace getActionSpace(){ - return actionSpace; - } -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java deleted file mode 100644 index 57002fb77..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientFactory.java +++ /dev/null @@ -1,90 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym; - -import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.GymObservationSpace; -import org.deeplearning4j.rl4j.space.HighLowDiscrete; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * ClientFactory contains builder method to create a new {@link Client} - */ -public class ClientFactory { - - public static > Client build(String url, String envId, boolean render) { - - JSONObject body = new JSONObject().put("env_id", envId); - JSONObject reply = ClientUtils.post(url + Client.ENVS_ROOT, body).getObject(); - - String instanceId; - - try { - instanceId = reply.getString("instance_id"); - } catch (JSONException e) { - throw new RuntimeException("Environment id not found", e); - } - - GymObservationSpace observationSpace = fetchObservationSpace(url, instanceId); - AS actionSpace = fetchActionSpace(url, instanceId); - - return new Client(url, envId, instanceId, observationSpace, actionSpace, render); - - } - - public static > Client build(String envId, boolean render) { - return build("http://127.0.0.1:5000", envId, render); - } - - public static AS fetchActionSpace(String url, String instanceId) { - - JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.ACTION_SPACE); - JSONObject info = reply.getJSONObject("info"); - String infoName = info.getString("name"); - - switch (infoName) { - case "Discrete": - return (AS) new DiscreteSpace(info.getInt("n")); - case "HighLow": - int numRows = info.getInt("num_rows"); - int size = 3 * numRows; - JSONArray matrixJson = info.getJSONArray("matrix"); - INDArray matrix = Nd4j.create(numRows, 3); - for (int i = 0; i < size; i++) { - matrix.putScalar(i, matrixJson.getDouble(i)); - } - matrix.reshape(numRows, 3); - return (AS) new HighLowDiscrete(matrix); - default: - throw new RuntimeException("Unknown action space " + infoName); - } - } - - public static GymObservationSpace fetchObservationSpace(String url, String instanceId) { - JSONObject reply = ClientUtils.get(url + Client.ENVS_ROOT + instanceId + Client.OBSERVATION_SPACE); - JSONObject info = reply.getJSONObject("info"); - return new GymObservationSpace(info); - } -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java b/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java deleted file mode 100644 index ff61026ae..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/ClientUtils.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym; - -import com.mashape.unirest.http.HttpResponse; -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.Unirest; -import com.mashape.unirest.http.exceptions.UnirestException; -import org.json.JSONException; -import org.json.JSONObject; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * ClientUtils contain the utility methods to post and get data from the server REST API through the library unirest. - */ -public class ClientUtils { - - static public JsonNode post(String url, JSONObject json) { - HttpResponse jsonResponse = null; - - try { - jsonResponse = Unirest.post(url).header("content-type", "application/json").body(json).asJson(); - } catch (UnirestException e) { - unirestCrash(e); - } - - return jsonResponse.getBody(); - } - - - static public JSONObject get(String url) { - HttpResponse jsonResponse = null; - - try { - jsonResponse = Unirest.get(url).header("content-type", "application/json").asJson(); - } catch (UnirestException e) { - unirestCrash(e); - } - - checkReply(jsonResponse, url); - - return jsonResponse.getBody().getObject(); - } - - - static public void checkReply(HttpResponse res, String url) { - if (res.getBody() == null) - throw new RuntimeException("Invalid reply at: " + url); - } - - static public void unirestCrash(UnirestException e) { - //if couldn't parse json - if (e.getCause().getCause().getCause() instanceof JSONException) - throw new RuntimeException("Couldn't parse json reply.", e); - else - throw new RuntimeException("Connection error", e); - } - - -} diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java b/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java deleted file mode 100644 index e8f9b2bbd..000000000 --- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/GymObservationSpace.java +++ /dev/null @@ -1,82 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.space; - -import lombok.Value; -import org.json.JSONArray; -import org.json.JSONObject; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. - * - * Contain contextual information about the environment from which Observations are observed and must know how to build an Observation from json. - * - * @param the type of Observation - */ - -@Value -public class GymObservationSpace implements ObservationSpace { - - String name; - int[] shape; - INDArray low; - INDArray high; - - - public GymObservationSpace(JSONObject jsonObject) { - - name = jsonObject.getString("name"); - - JSONArray arr = jsonObject.getJSONArray("shape"); - int lg = arr.length(); - - shape = new int[lg]; - for (int i = 0; i < lg; i++) { - this.shape[i] = arr.getInt(i); - } - - low = Nd4j.create(shape); - high = Nd4j.create(shape); - - JSONArray lowJson = jsonObject.getJSONArray("low"); - JSONArray highJson = jsonObject.getJSONArray("high"); - - int size = shape[0]; - for (int i = 1; i < shape.length; i++) { - size *= shape[i]; - } - - for (int i = 0; i < size; i++) { - low.putScalar(i, lowJson.getDouble(i)); - high.putScalar(i, highJson.getDouble(i)); - } - - } - - public O getValue(JSONObject o, String key) { - switch (name) { - case "Box": - JSONArray arr = o.getJSONArray(key); - return (O) new Box(arr); - default: - throw new RuntimeException("Invalid environment name: " + name); - } - } - -} diff --git a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java b/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java deleted file mode 100644 index ab31a488b..000000000 --- a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/ClientTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym.test; - -import com.mashape.unirest.http.JsonNode; -import org.deeplearning4j.gym.Client; -import org.deeplearning4j.gym.ClientFactory; -import org.deeplearning4j.gym.ClientUtils; -import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Box; -import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.json.JSONObject; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import static org.mockito.Matchers.eq; -import static org.powermock.api.mockito.PowerMockito.mockStatic; -import static org.powermock.api.mockito.PowerMockito.when; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16. - */ - -@RunWith(PowerMockRunner.class) -@PrepareForTest(ClientUtils.class) -public class ClientTest { - - @Test - public void completeClientTest() { - - String url = "http://127.0.0.1:5000"; - String env = "Powermock-v0"; - String instanceID = "e15739cf"; - String testDir = "/tmp/testDir"; - boolean render = true; - String renderStr = render ? "True" : "False"; - - mockStatic(ClientUtils.class); - - //post mock - - JSONObject buildReq = new JSONObject("{\"env_id\":\"" + env + "\"}"); - JsonNode buildRep = new JsonNode("{\"instance_id\":\"" + instanceID + "\"}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT), JSONObjectMatcher.jsonEq(buildReq))).thenReturn(buildRep); - - JSONObject monStartReq = new JSONObject("{\"resume\":false,\"directory\":\"" + testDir + "\",\"force\":true}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_START), - JSONObjectMatcher.jsonEq(monStartReq))).thenReturn(null); - - JSONObject monStopReq = new JSONObject("{}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.MONITOR_CLOSE), - JSONObjectMatcher.jsonEq(monStopReq))).thenReturn(null); - - JSONObject resetReq = new JSONObject("{}"); - JsonNode resetRep = new JsonNode( - "{\"observation\":[0.021729452941849317,-0.04764548144956857,-0.024914502756611293,-0.04074903379512588]}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.RESET), - JSONObjectMatcher.jsonEq(resetReq))).thenReturn(resetRep); - - JSONObject stepReq = new JSONObject("{\"action\":0, \"render\":" + renderStr + "}"); - JsonNode stepRep = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), JSONObjectMatcher.jsonEq(stepReq))) - .thenReturn(stepRep); - - JSONObject stepReq2 = new JSONObject("{\"action\":1, \"render\":" + renderStr + "}"); - JsonNode stepRep2 = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":false,\"info\":{}}"); - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), - JSONObjectMatcher.jsonEq(stepReq2))).thenReturn(stepRep2); - - //get mock - JSONObject obsSpace = new JSONObject( - "{\"info\":{\"name\":\"Box\",\"shape\":[4],\"high\":[4.8,3.4028234663852886E38,0.41887902047863906,3.4028234663852886E38],\"low\":[-4.8,-3.4028234663852886E38,-0.41887902047863906,-3.4028234663852886E38]}}"); - when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.OBSERVATION_SPACE))).thenReturn(obsSpace); - - JSONObject actionSpace = new JSONObject("{\"info\":{\"name\":\"Discrete\",\"n\":2}}"); - when(ClientUtils.get(eq(url + Client.ENVS_ROOT + instanceID + Client.ACTION_SPACE))).thenReturn(actionSpace); - - - //test - - Client client = ClientFactory.build(url, env, render); - client.monitorStart(testDir, true, false); - - int episodeCount = 1; - int maxSteps = 200; - int reward = 0; - - for (int i = 0; i < episodeCount; i++) { - client.reset(); - - for (int j = 0; j < maxSteps; j++) { - - Integer action = ((ActionSpace)client.getActionSpace()).randomAction(); - StepReply step = client.step(action); - reward += step.getReward(); - - //return a isDone true before i == maxSteps - if (j == maxSteps - 5) { - JSONObject stepReqLoc = new JSONObject("{\"action\":0}"); - JsonNode stepRepLoc = new JsonNode( - "{\"observation\":[0.020776543312857946,-0.24240146656155923,-0.02572948343251381,0.24397017400615437],\"reward\":1,\"done\":true,\"info\":{}}"); - - when(ClientUtils.post(eq(url + Client.ENVS_ROOT + instanceID + Client.STEP), - JSONObjectMatcher.jsonEq(stepReqLoc))).thenReturn(stepRepLoc); - } - - if (step.isDone()) { - // System.out.println("break"); - break; - } - } - - } - - client.monitorClose(); - client.upload(testDir, "YOUR_OPENAI_GYM_API_KEY"); - - - } - -} diff --git a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java b/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java deleted file mode 100644 index ff1430786..000000000 --- a/gym-java-client/src/test/java/org/deeplearning4j/gym/test/JSONObjectMatcher.java +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.gym.test; - -import org.json.JSONObject; -import org.mockito.ArgumentMatcher; - -import static org.mockito.Matchers.argThat; - -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/11/16. - */ - - -public class JSONObjectMatcher implements ArgumentMatcher { - private final JSONObject expected; - - public JSONObjectMatcher(JSONObject expected) { - this.expected = expected; - } - - public static JSONObject jsonEq(JSONObject expected) { - return argThat(new JSONObjectMatcher(expected)); - } - - - @Override - public boolean matches(JSONObject argument) { - if (expected == null) - return argument == null; - return expected.toString().equals(argument.toString()); } -} diff --git a/pom.xml b/pom.xml index ada833f12..3d7082524 100644 --- a/pom.xml +++ b/pom.xml @@ -132,7 +132,6 @@ deeplearning4j arbiter nd4s - gym-java-client rl4j scalnet jumpy @@ -288,22 +287,23 @@ ${javacpp.platform} - 1.5.2 - 1.5.2 - 1.5.2 + 1.5.3-SNAPSHOT + 1.5.3-SNAPSHOT + 1.5.3-SNAPSHOT - 3.7.5 + 3.7.6 ${python.version}-${javacpp-presets.version} - 1.17.3 + 1.18.0 ${numpy.version}-${javacpp-presets.version} 0.3.7 2019.5 - 4.1.2 - 4.2.1 - 1.78.0 - 1.10.5 - 0.6.0 + 4.2.0 + 4.2.2 + 1.79.0 + 1.10.6 + 0.6.1 + 0.15.4 1.15.0 ${tensorflow.version}-${javacpp-presets.version} diff --git a/rl4j/README.md b/rl4j/README.md index ef5797701..68d5b755e 100644 --- a/rl4j/README.md +++ b/rl4j/README.md @@ -32,10 +32,6 @@ Comments are welcome on our gitter channel: # Quickstart -** INSTALL rl4j-api before installing all (see below)!** - -* mvn install -pl rl4j-api -* [if you want rl4j-gym too] Download and mvn install: [gym-java-client](https://github.com/eclipse/deeplearning4j/tree/master/gym-java-client) * mvn install # Visualisation @@ -44,9 +40,7 @@ Comments are welcome on our gitter channel: # Quicktry cartpole: -* Install [gym-http-api](https://github.com/openai/gym-http-api). -* launch http api server. -* run with this [main](https://github.com/rubenfiszel/rl4j-examples/blob/master/src/main/java/org/deeplearning4j/rl4j/Cartpole.java) +* run with this [main](https://github.com/eclipse/deeplearning4j-examples/blob/master/rl4j-examples/src/main/java/org/deeplearning4j/examples/rl4j/Cartpole.java) # Doom @@ -83,4 +77,4 @@ Doom is not ready yet but you can make it work if you feel adventurous with some * Continuous control * Policy Gradient -* Update gym-java-client when gym-http-api gets compatible with pixels environments to play with Pong, Doom, etc .. +* Update rl4j-gym to make it compatible with pixels environments to play with Pong, Doom, etc .. diff --git a/rl4j/pom.xml b/rl4j/pom.xml index 318446629..c91dd7aa2 100644 --- a/rl4j/pom.xml +++ b/rl4j/pom.xml @@ -97,6 +97,30 @@ + + org.apache.maven.plugins + maven-enforcer-plugin + ${maven-enforcer-plugin.version} + + + test + enforce-choice-of-nd4j-test-backend + + enforce + + + ${skipBackendChoice} + + + test-nd4j-native,test-nd4j-cuda-10.2 + false + + + true + + + + maven-source-plugin ${maven-source-plugin.version} @@ -265,6 +289,32 @@ + + + test-nd4j-native + + + org.nd4j + nd4j-native + ${nd4j.version} + test + + + + + + test-nd4j-cuda-10.2 + + + org.nd4j + nd4j-cuda-10.2 + ${nd4j.version} + test + + + + + diff --git a/rl4j/rl4j-ale/pom.xml b/rl4j/rl4j-ale/pom.xml index 33211f480..360c3db1d 100644 --- a/rl4j/rl4j-ale/pom.xml +++ b/rl4j/rl4j-ale/pom.xml @@ -44,4 +44,13 @@ ${ale.version}-${javacpp-presets.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.2 + + diff --git a/rl4j/rl4j-api/pom.xml b/rl4j/rl4j-api/pom.xml index b89c875cc..629783e15 100644 --- a/rl4j/rl4j-api/pom.xml +++ b/rl4j/rl4j-api/pom.xml @@ -33,15 +33,19 @@ - - org.deeplearning4j - gym-java-client - ${dl4j.version} - org.nd4j nd4j-api ${nd4j.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.2 + + diff --git a/gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java similarity index 95% rename from gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java index 0da2971fe..ab054689a 100644 --- a/gym-java-client/src/main/java/org/deeplearning4j/gym/StepReply.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java @@ -17,7 +17,6 @@ package org.deeplearning4j.gym; import lombok.Value; -import org.json.JSONObject; /** * @param type of observation @@ -31,6 +30,6 @@ public class StepReply { T observation; double reward; boolean done; - JSONObject info; + Object info; } diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java similarity index 100% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java similarity index 100% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java similarity index 83% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java index c938747d5..e90601fda 100644 --- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Box.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java @@ -16,8 +16,6 @@ package org.deeplearning4j.rl4j.space; -import org.json.JSONArray; - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. * @@ -29,14 +27,8 @@ public class Box implements Encodable { private final double[] array; - public Box(JSONArray arr) { - - int lg = arr.length(); - this.array = new double[lg]; - - for (int i = 0; i < lg; i++) { - this.array[i] = arr.getDouble(i); - } + public Box(double[] arr) { + this.array = arr; } public double[] toArray() { diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java similarity index 100% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java similarity index 100% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java similarity index 82% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java index d9c77f285..491f6aca1 100644 --- a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java @@ -17,10 +17,8 @@ package org.deeplearning4j.rl4j.space; import lombok.Value; -import org.json.JSONArray; import org.nd4j.linalg.api.ndarray.INDArray; - /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/26/16. */ @@ -37,12 +35,8 @@ public class HighLowDiscrete extends DiscreteSpace { @Override public Object encode(Integer a) { - JSONArray jsonArray = new JSONArray(); - for (int i = 0; i < size; i++) { - jsonArray.put(matrix.getDouble(i, 0)); - } - jsonArray.put(a - 1, matrix.getDouble(a - 1, 1)); - return jsonArray; + INDArray m = matrix.dup(); + m.put(a - 1, 0, matrix.getDouble(a - 1, 1)); + return m; } - } diff --git a/gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java similarity index 100% rename from gym-java-client/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java rename to rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index ebdfec1a6..a78157603 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -44,11 +44,6 @@ ch.qos.logback logback-classic - - org.deeplearning4j - gym-java-client - ${dl4j.version} - org.bytedeco @@ -111,27 +106,10 @@ - nd4j-tests-cpu - - - org.nd4j - nd4j-native - ${project.version} - test - - + test-nd4j-native - - nd4j-tests-cuda - - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - test - - + test-nd4j-cuda-10.2 diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java index a9708f550..2d8bc0402 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java @@ -23,7 +23,6 @@ import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.space.ArrayObservationSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; -import org.json.JSONObject; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -104,7 +103,7 @@ public class HardDeteministicToy implements MDP { public StepReply step(Integer a) { double reward = (simpleToyState.getStep() % 2 == 0) ? 1 - a : a; simpleToyState = new SimpleToyState(simpleToyState.getI() + 1, simpleToyState.getStep() + 1); - return new StepReply<>(simpleToyState, reward, isDone(), new JSONObject("{}")); + return new StepReply<>(simpleToyState, reward, isDone(), null); } public SimpleToy newInstance() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java index 4f7d6244e..b639efdaa 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java @@ -26,6 +26,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.Learning; +import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.util.ModelSerializer; @@ -88,19 +89,38 @@ public class DataManager implements IDataManager { String json = new ObjectMapper().writeValueAsString(learning.getConfiguration()); writeEntry(new ByteArrayInputStream(json.getBytes()), zipfile); - ZipEntry dqn = new ZipEntry("dqn.bin"); - zipfile.putNextEntry(dqn); + try { + ZipEntry dqn = new ZipEntry("dqn.bin"); + zipfile.putNextEntry(dqn); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - if(learning instanceof NeuralNetFetchable) { - ((NeuralNetFetchable)learning).getNeuralNet().save(bos); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + if(learning instanceof NeuralNetFetchable) { + ((NeuralNetFetchable)learning).getNeuralNet().save(bos); + } + bos.flush(); + bos.close(); + + InputStream inputStream = new ByteArrayInputStream(bos.toByteArray()); + writeEntry(inputStream, zipfile); + } catch (UnsupportedOperationException e) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ByteArrayOutputStream bos2 = new ByteArrayOutputStream(); + ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(bos, bos2); + + bos.flush(); + bos.close(); + InputStream inputStream = new ByteArrayInputStream(bos.toByteArray()); + ZipEntry value = new ZipEntry("value.bin"); + zipfile.putNextEntry(value); + writeEntry(inputStream, zipfile); + + bos2.flush(); + bos2.close(); + InputStream inputStream2 = new ByteArrayInputStream(bos2.toByteArray()); + ZipEntry policy = new ZipEntry("policy.bin"); + zipfile.putNextEntry(policy); + writeEntry(inputStream2, zipfile); } - bos.flush(); - bos.close(); - - InputStream inputStream = new ByteArrayInputStream(bos.toByteArray()); - writeEntry(inputStream, zipfile); - if (learning.getHistoryProcessor() != null) { ZipEntry hpconf = new ZipEntry("hpconf.bin"); @@ -268,7 +288,12 @@ public class DataManager implements IDataManager { save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning); if(learning instanceof NeuralNetFetchable) { - ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); + try { + ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model"); + } catch (UnsupportedOperationException e) { + String path = getModelDir() + "/" + learning.getStepCounter(); + ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model"); + } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index 221409040..dbcd38ddc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -47,7 +47,10 @@ public class LegacyMDPWrapper> implements MDP${project.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.2 + + diff --git a/rl4j/rl4j-gym/pom.xml b/rl4j/rl4j-gym/pom.xml index d5be31032..76e2e39ff 100644 --- a/rl4j/rl4j-gym/pom.xml +++ b/rl4j/rl4j-gym/pom.xml @@ -39,9 +39,18 @@ ${project.version} - org.deeplearning4j - gym-java-client - ${project.version} + org.bytedeco + gym-platform + ${gym.version}-${javacpp-presets.version} + + + + test-nd4j-native + + + test-nd4j-cuda-10.2 + + diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java index 8f18cb29e..66be7698f 100644 --- a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java +++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,39 +17,111 @@ package org.deeplearning4j.rl4j.mdp.gym; - -import org.deeplearning4j.gym.Client; -import org.deeplearning4j.gym.ClientFactory; +import java.io.IOException; +import lombok.Getter; +import lombok.Setter; +import lombok.Value; +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.DoublePointer; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.SizeTPointer; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.space.ArrayObservationSpace; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Box; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.HighLowDiscrete; import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.bytedeco.cpython.*; +import org.bytedeco.numpy.*; +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.numpy.global.numpy.*; + /** + * An MDP for OpenAI Gym: https://gym.openai.com/ + * * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/12/16. - * - * Wrapper over the client of gym-java-client - * + * @author saudet */ +@Slf4j public class GymEnv> implements MDP { - final public static String GYM_MONITOR_DIR = "/tmp/gym-dqn"; + public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn"; - final private Client client; + private static void checkPythonError() { + if (PyErr_Occurred() != null) { + PyErr_Print(); + throw new RuntimeException("Python error occurred"); + } + } + + private static Pointer program; + private static PyObject globals; + static { + try { + Py_SetPath(org.bytedeco.gym.presets.gym.cachePackages()); + program = Py_DecodeLocale(GymEnv.class.getSimpleName(), null); + Py_SetProgramName(program); + Py_Initialize(); + PyEval_InitThreads(); + PySys_SetArgvEx(1, program, 0); + if (_import_array() < 0) { + PyErr_Print(); + throw new RuntimeException("numpy.core.multiarray failed to import"); + } + globals = PyModule_GetDict(PyImport_AddModule("__main__")); + PyEval_SaveThread(); // just to release the GIL + } catch (IOException e) { + PyMem_RawFree(program); + throw new RuntimeException(e); + } + } + private PyObject locals; + + final protected DiscreteSpace actionSpace; + final protected ObservationSpace observationSpace; + @Getter final private String envId; + @Getter final private boolean render; + @Getter final private boolean monitor; private ActionTransformer actionTransformer = null; private boolean done = false; public GymEnv(String envId, boolean render, boolean monitor) { - this.client = ClientFactory.build(envId, render); this.envId = envId; this.render = render; this.monitor = monitor; - if (monitor) - client.monitorStart(GYM_MONITOR_DIR, true, false); + + int gstate = PyGILState_Ensure(); + try { + locals = PyDict_New(); + + Py_DecRef(PyRun_StringFlags("import gym; env = gym.make('" + envId + "')", Py_single_input, globals, locals, null)); + checkPythonError(); + if (monitor) { + Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null)); + checkPythonError(); + } + PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null); + int[] shape = new int[(int)PyTuple_Size(shapeTuple)]; + for (int i = 0; i < shape.length; i++) { + shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i)); + } + observationSpace = (ObservationSpace) new ArrayObservationSpace(shape); + Py_DecRef(shapeTuple); + + PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null); + actionSpace = new DiscreteSpace((int)PyLong_AsLong(n)); + Py_DecRef(n); + checkPythonError(); + } finally { + PyGILState_Release(gstate); + } } public GymEnv(String envId, boolean render, boolean monitor, int[] actions) { @@ -56,43 +129,87 @@ public class GymEnv> implements MDP { actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions); } - + @Override public ObservationSpace getObservationSpace() { - return client.getObservationSpace(); + return observationSpace; } + @Override public AS getActionSpace() { if (actionTransformer == null) - return (AS) client.getActionSpace(); + return (AS) actionSpace; else return (AS) actionTransformer; } + @Override public StepReply step(A action) { - StepReply stepRep = client.step(action); - done = stepRep.isDone(); - return stepRep; + int gstate = PyGILState_Ensure(); + try { + if (render) { + Py_DecRef(PyRun_StringFlags("env.render()", Py_single_input, globals, locals, null)); + checkPythonError(); + } + Py_DecRef(PyRun_StringFlags("state, reward, done, info = env.step(" + (Integer)action +")", Py_single_input, globals, locals, null)); + checkPythonError(); + + PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state")); + DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state)); + SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state)); + + double reward = PyFloat_AsDouble(PyDict_GetItemString(locals, "reward")); + done = PyLong_AsLong(PyDict_GetItemString(locals, "done")) != 0; + checkPythonError(); + + double[] data = new double[(int)stateData.capacity()]; + stateData.get(data); + + return new StepReply(new Box(data), reward, done, null); + } finally { + PyGILState_Release(gstate); + } } + @Override public boolean isDone() { return done; } + @Override public O reset() { - done = false; - return client.reset(); - } - - - public void upload(String apiKey) { - client.upload(GYM_MONITOR_DIR, apiKey); + int gstate = PyGILState_Ensure(); + try { + Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null)); + checkPythonError(); + + PyArrayObject state = new PyArrayObject(PyDict_GetItemString(locals, "state")); + DoublePointer stateData = new DoublePointer(PyArray_BYTES(state)).capacity(PyArray_Size(state)); + SizeTPointer stateDims = PyArray_DIMS(state).capacity(PyArray_NDIM(state)); + checkPythonError(); + + done = false; + + double[] data = new double[(int)stateData.capacity()]; + stateData.get(data); + return (O) new Box(data); + } finally { + PyGILState_Release(gstate); + } } + @Override public void close() { - if (monitor) - client.monitorClose(); + int gstate = PyGILState_Ensure(); + try { + Py_DecRef(PyRun_StringFlags("env.close()", Py_single_input, globals, locals, null)); + checkPythonError(); + Py_DecRef(locals); + } finally { + PyGILState_Release(gstate); + } } + @Override public GymEnv newInstance() { return new GymEnv(envId, render, monitor); } diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java new file mode 100644 index 000000000..2196d7b31 --- /dev/null +++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.mdp.gym; + +import org.deeplearning4j.gym.StepReply; +import org.deeplearning4j.rl4j.space.ArrayObservationSpace; +import org.deeplearning4j.rl4j.space.Box; +import org.deeplearning4j.rl4j.space.DiscreteSpace; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +/** + * + * @author saudet + */ +public class GymEnvTest { + + @Test + public void testCartpole() { + GymEnv mdp = new GymEnv("CartPole-v0", false, false); + assertArrayEquals(new int[] {4}, ((ArrayObservationSpace)mdp.getObservationSpace()).getShape()); + assertEquals(2, ((DiscreteSpace)mdp.getActionSpace()).getSize()); + assertEquals(false, mdp.isDone()); + Box o = (Box)mdp.reset(); + StepReply r = mdp.step(0); + assertEquals(4, o.toArray().length); + assertEquals(4, ((Box)r.getObservation()).toArray().length); + assertNotEquals(null, mdp.newInstance()); + mdp.close(); + } +} diff --git a/rl4j/rl4j-malmo/pom.xml b/rl4j/rl4j-malmo/pom.xml index 53ef7be13..3575da7a8 100644 --- a/rl4j/rl4j-malmo/pom.xml +++ b/rl4j/rl4j-malmo/pom.xml @@ -33,6 +33,11 @@ + + org.json + json + 20190722 + org.deeplearning4j rl4j-api @@ -44,4 +49,13 @@ 0.30.0 + + + + test-nd4j-native + + + test-nd4j-cuda-10.2 + +