From 5959ff47959864e58ac5e46956bfd1c3a8b8aeb8 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Mon, 30 Sep 2019 20:27:51 -0400 Subject: [PATCH] RL4J: Fix QLearningDiscrete.setTarget() and add CartpoleNative (#8250) * Fixed QLearningDiscrete.setTarget() Signed-off-by: Alexandre Boulanger * Added native java version of Cartpole Signed-off-by: unknown --- .../qlearning/discrete/QLearningDiscrete.java | 8 +- .../rl4j/mdp/CartpoleNative.java | 159 ++++++++++++++++++ 2 files changed, 161 insertions(+), 6 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 7ea47eba8..bc8fb37d2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -218,15 +218,14 @@ public abstract class QLearningDiscrete extends QLearning extends QLearning extends QLearning { + public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; + + private static final int NUM_ACTIONS = 2; + private static final int ACTION_LEFT = 0; + private static final int ACTION_RIGHT = 1; + private static final int OBSERVATION_NUM_FEATURES = 4; + + private static final double gravity = 9.8; + private static final double massCart = 1.0; + private static final double massPole = 0.1; + private static final double totalMass = massPole + massCart; + private static final double length = 0.5; // actually half the pole's length + private static final double polemassLength = massPole * length; + private static final double forceMag = 10.0; + private static final double tau = 0.02; // seconds between 
state updates
+
+    // Angle at which to fail the episode (12 degrees, expressed in radians)
+    private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
+    // Cart position (absolute value) beyond which the episode fails
+    private static final double xThreshold = 2.4;
+
+    private final Random rnd = new Random();
+
+    /** Integration scheme used by {@link #step(Integer)} to advance the state. */
+    @Getter @Setter
+    private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
+
+    /**
+     * Whether the most recent {@link #step(Integer)} left the cart or the pole outside the
+     * allowed range. Cleared by {@link #reset()}.
+     */
+    @Getter
+    private boolean done = false;
+
+    // Current physical state: cart position / velocity, pole angle / angular velocity.
+    private double x;
+    private double xDot;
+    private double theta;
+    private double thetaDot;
+    // null until the first step that ends the episode; afterwards counts the extra steps taken.
+    private Integer stepsBeyondDone;
+
+    @Getter
+    private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
+    @Getter
+    private ObservationSpace observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
+
+    /**
+     * Starts a new episode: every state variable is drawn uniformly from [-0.05, 0.05)
+     * and the termination bookkeeping is cleared.
+     *
+     * @return the initial observation {x, xDot, theta, thetaDot}
+     */
+    @Override
+    public State reset() {
+
+        x = 0.1 * rnd.nextDouble() - 0.05;
+        xDot = 0.1 * rnd.nextDouble() - 0.05;
+        theta = 0.1 * rnd.nextDouble() - 0.05;
+        thetaDot = 0.1 * rnd.nextDouble() - 0.05;
+        stepsBeyondDone = null;
+        // A fresh episode is never done; without this isDone() would stay true after a failure.
+        done = false;
+
+        return new State(new double[] { x, xDot, theta, thetaDot });
+    }
+
+    /** Nothing to release: this MDP holds no external resources. */
+    @Override
+    public void close() {
+
+    }
+
+    /**
+     * Advances the simulation by one time step of {@code tau} seconds.
+     *
+     * @param action {@code ACTION_RIGHT} pushes the cart with {@code +forceMag};
+     *               any other value pushes with {@code -forceMag}
+     * @return the new observation, the reward (1.0 while the episode is running and on the
+     *         first step that ends it, 0 on any later step) and the done flag
+     */
+    @Override
+    public StepReply step(Integer action) {
+        double force = action == ACTION_RIGHT ? forceMag : -forceMag;
+        double cosTheta = Math.cos(theta);
+        double sinTheta = Math.sin(theta);
+        double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
+        double thetaAcc = (gravity * sinTheta - cosTheta * temp) / (length * (4.0/3.0 - massPole * cosTheta * cosTheta / totalMass));
+        double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;
+
+        switch(kinematicsIntegrator) {
+            case Euler:
+                // Positions are advanced with the velocities from before this step.
+                x += tau * xDot;
+                xDot += tau * xAcc;
+                theta += tau * thetaDot;
+                thetaDot += tau * thetaAcc;
+                break;
+
+            case SemiImplicitEuler:
+                // Velocities are updated first, then used to advance the positions.
+                xDot += tau * xAcc;
+                x += tau * xDot;
+                thetaDot += tau * thetaAcc;
+                theta += tau * thetaDot;
+                break;
+        }
+
+        // Assign to the field rather than a shadowing local, so that isDone() reports the
+        // same value that is returned in the StepReply.
+        done = x < -xThreshold || x > xThreshold
+                || theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
+
+        double reward;
+        if(!done) {
+            reward = 1.0;
+        }
+        else if(stepsBeyondDone == null) {
+            // First step that ends the episode still earns the reward.
+            stepsBeyondDone = 0;
+            reward = 1.0;
+        }
+        else {
+            // Stepping after the episode has already ended earns nothing.
+            ++stepsBeyondDone;
+            reward = 0;
+        }
+
+        return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null);
+    }
+
+    /**
+     * @return a new, independent {@code CartpoleNative} with default settings
+     *         (the configured kinematics integrator is not copied)
+     */
+    @Override
+    public MDP newInstance() {
+        return new CartpoleNative();
+    }
+
+    /** Observation wrapper around the array built by {@code reset()} and {@code step()}. */
+    public static class State implements Encodable {
+
+        private final double[] state;
+
+        State(double[] state) {
+
+            this.state = state;
+        }
+
+        /** @return the backing array itself (not a copy) */
+        @Override
+        public double[] toArray() {
+            return state;
+        }
+    }
+}