RL4J: Fix QLearningDiscrete.setTarget() and add CartpoleNative (#8250)
* Fixed QLearningDiscrete.setTarget()
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Added native Java version of Cartpole
  Signed-off-by: unknown <aboulang2002@yahoo.com>

Branch: master
parent d5e98afcef
commit 5959ff4795
@@ -218,15 +218,14 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
         INDArray dqnOutputAr = dqnOutput(obs);
 
         INDArray dqnOutputNext = dqnOutput(nextObs);
-        INDArray targetDqnOutputNext = null;
+        INDArray targetDqnOutputNext = targetDqnOutput(nextObs);
 
         INDArray tempQ = null;
         INDArray getMaxAction = null;
         if (getConfiguration().isDoubleDQN()) {
-            targetDqnOutputNext = targetDqnOutput(nextObs);
             getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
         } else {
-            tempQ = Nd4j.max(dqnOutputNext, 1);
+            tempQ = Nd4j.max(targetDqnOutputNext, 1);
         }
 
 
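The fix matters in both branches: with standard (non-double) DQN, the bootstrap value must come from the target network, while with double DQN the online network selects the action and the target network evaluates it. A minimal sketch of the two target rules, reusing the arrays from the hunk above; rewards, gamma, and the row index i are illustrative stand-ins, and terminal-state handling is omitted:

    // Standard DQN:  y_i = r_i + gamma * max_a Qtarget(s'_i, a)
    INDArray maxQNext = Nd4j.max(targetDqnOutputNext, 1);
    double yStandard = rewards[i] + gamma * maxQNext.getDouble(i);

    // Double DQN:  y_i = r_i + gamma * Qtarget(s'_i, argmax_a Qonline(s'_i, a))
    INDArray bestAction = Nd4j.argMax(dqnOutputNext, 1);
    double yDouble = rewards[i] + gamma * targetDqnOutputNext.getDouble(i, bestAction.getInt(i));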
@@ -243,8 +242,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
         }
 
         double previousV = dqnOutputAr.getDouble(i, actions[i]);
         double lowB = previousV - getConfiguration().getErrorClamp();
         double highB = previousV + getConfiguration().getErrorClamp();
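The context lines above show the error clamp: the regression target for the chosen action is kept within getErrorClamp() of the network's current estimate, which bounds the per-sample TD error. A one-line sketch, where target is an illustrative name for the computed Q-learning target:

    // clip the new target into [previousV - errorClamp, previousV + errorClamp]
    double clampedTarget = Math.min(highB, Math.max(lowB, target));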
@@ -255,5 +252,4 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
 
         return new Pair(obs, dqnOutputAr);
     }
 }
new file: CartpoleNative.java
@@ -0,0 +1,159 @@
package org.deeplearning4j.rl4j.mdp;

import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;

import java.util.Random;

/*
With the setup below, it should hit the max score (200) after 4000-5000 iterations.

    public static QLearning.QLConfiguration CARTPOLE_QL =
            new QLearning.QLConfiguration(
                    123,              // Random seed
                    200,              // Max step per epoch
                    10000,            // Max step
                    10000,            // Max size of experience replay
                    64,               // Size of batches
                    50,               // Target update (hard)
                    0,                // Num step noop warmup
                    1.0,              // Reward scaling
                    0.99,             // Gamma
                    Double.MAX_VALUE, // TD-error clipping
                    0.1f,             // Min epsilon
                    3000,             // Num step for epsilon-greedy anneal
                    true              // Double DQN
            );

    public static DQNFactoryStdDense.Configuration CARTPOLE_NET =
            DQNFactoryStdDense.Configuration.builder()
                    .l2(0.001).updater(new Adam(0.0005)).numHiddenNodes(16).numLayer(3).build();
*/

public class CartpoleNative implements MDP<CartpoleNative.State, Integer, DiscreteSpace> {
    public enum KinematicsIntegrators { Euler, SemiImplicitEuler }

    private static final int NUM_ACTIONS = 2;
    private static final int ACTION_LEFT = 0;
    private static final int ACTION_RIGHT = 1;
    private static final int OBSERVATION_NUM_FEATURES = 4;

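    // The physics constants and failure thresholds below follow the classic
    // cart-pole formulation (Barto, Sutton & Anderson, 1983), the same values
    // used by OpenAI Gym's CartPole environment.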
    private static final double gravity = 9.8;
    private static final double massCart = 1.0;
    private static final double massPole = 0.1;
    private static final double totalMass = massPole + massCart;
    private static final double length = 0.5; // actually half the pole's length
    private static final double polemassLength = massPole * length;
    private static final double forceMag = 10.0;
    private static final double tau = 0.02; // seconds between state updates

    // Angle at which to fail the episode
    private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
    private static final double xThreshold = 2.4;

    private final Random rnd = new Random();

    @Getter @Setter
    private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;

    @Getter
    private boolean done = false;

    private double x;
    private double xDot;
    private double theta;
    private double thetaDot;
    private Integer stepsBeyondDone;

    @Getter
    private DiscreteSpace actionSpace = new DiscreteSpace(NUM_ACTIONS);
    @Getter
    private ObservationSpace<CartpoleNative.State> observationSpace = new ArrayObservationSpace<>(new int[] { OBSERVATION_NUM_FEATURES });

    @Override
    public State reset() {
        // initialize each state variable uniformly in [-0.05, 0.05)
        x = 0.1 * rnd.nextDouble() - 0.05;
        xDot = 0.1 * rnd.nextDouble() - 0.05;
        theta = 0.1 * rnd.nextDouble() - 0.05;
        thetaDot = 0.1 * rnd.nextDouble() - 0.05;
        stepsBeyondDone = null;
        done = false; // clear the previous episode's terminal flag

        return new State(new double[] { x, xDot, theta, thetaDot });
    }

    @Override
    public void close() {
        // nothing to release in this pure-Java environment
    }

    @Override
    public StepReply<State> step(Integer action) {
        double force = action == ACTION_RIGHT ? forceMag : -forceMag;
        double cosTheta = Math.cos(theta);
        double sinTheta = Math.sin(theta);
        double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
        double thetaAcc = (gravity * sinTheta - cosTheta * temp)
                / (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
        double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;

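        // Integrate the accelerations over one time step of length tau.
        // Euler moves the positions with the old velocities; semi-implicit
        // Euler updates the velocities first and moves with the new ones,
        // which tends to be the more stable of the two.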
        switch (kinematicsIntegrator) {
            case Euler:
                x += tau * xDot;
                xDot += tau * xAcc;
                theta += tau * thetaDot;
                thetaDot += tau * thetaAcc;
                break;

            case SemiImplicitEuler:
                xDot += tau * xAcc;
                x += tau * xDot;
                thetaDot += tau * thetaAcc;
                theta += tau * thetaDot;
                break;
        }

        // assign the field (not a shadowing local) so isDone() reflects termination
        done = x < -xThreshold || x > xThreshold
                || theta < -thetaThresholdRadians || theta > thetaThresholdRadians;

        double reward;
        if (!done) {
            reward = 1.0;
        }
        else if (stepsBeyondDone == null) {
            // the episode just ended: this is the last rewarded step
            stepsBeyondDone = 0;
            reward = 1.0;
        }
        else {
            ++stepsBeyondDone;
            reward = 0;
        }

        return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null);
    }

    @Override
    public MDP<State, Integer, DiscreteSpace> newInstance() {
        return new CartpoleNative();
    }

    public static class State implements Encodable {

        private final double[] state;

        State(double[] state) {
            this.state = state;
        }

        @Override
        public double[] toArray() {
            return state;
        }
    }
}
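For reference, training on this MDP follows the same pattern as RL4J's gym-based Cartpole example, wiring the CARTPOLE_QL and CARTPOLE_NET configurations from the comment block above into QLearningDiscreteDense. This is a sketch only; the QLearningDiscreteDense and DataManager constructor signatures vary across RL4J versions:

    CartpoleNative mdp = new CartpoleNative();
    QLearningDiscreteDense<CartpoleNative.State> dql =
            new QLearningDiscreteDense<>(mdp, CARTPOLE_NET, CARTPOLE_QL, new DataManager(true));
    dql.train();   // should reach the max score of 200 after 4000-5000 iterations
    mdp.close();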