RL4J: Add Agent and Environment (#358)
* Added Agent and Environment
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Added headers
  Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
* Fix compilation errors
  Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

branch: master
parent: a10fd4524a
commit: 550e84ef43
org/deeplearning4j/rl4j/agent/Agent.java (new file)
@@ -0,0 +1,210 @@
package org.deeplearning4j.rl4j.agent;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.agent.listener.AgentListenerList;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.base.Preconditions;

import java.util.Map;

public class Agent<ACTION> {
    @Getter
    private final String id;

    @Getter
    private final Environment<ACTION> environment;

    @Getter
    private final IPolicy<ACTION> policy;

    private final TransformProcess transformProcess;

    protected final AgentListenerList<ACTION> listeners;

    private final Integer maxEpisodeSteps;

    @Getter(AccessLevel.PROTECTED)
    private Observation observation;

    @Getter(AccessLevel.PROTECTED)
    private ACTION lastAction;

    @Getter
    private int episodeStepNumber;

    @Getter
    private double reward;

    protected boolean canContinue;

    private Agent(Builder<ACTION> builder) {
        this.environment = builder.environment;
        this.transformProcess = builder.transformProcess;
        this.policy = builder.policy;
        this.maxEpisodeSteps = builder.maxEpisodeSteps;
        this.id = builder.id;

        listeners = buildListenerList();
    }

    protected AgentListenerList<ACTION> buildListenerList() {
        return new AgentListenerList<ACTION>();
    }

    public void addListener(AgentListener listener) {
        listeners.add(listener);
    }

    public void run() {
        runEpisode();
    }

    protected void onBeforeEpisode() {
        // Do Nothing
    }

    protected void onAfterEpisode() {
        // Do Nothing
    }

    protected void runEpisode() {
        reset();
        onBeforeEpisode();

        canContinue = listeners.notifyBeforeEpisode(this);

        while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
            performStep();
        }

        if (!canContinue) {
            return;
        }

        onAfterEpisode();
    }

    protected void reset() {
        resetEnvironment();
        resetPolicy();
        reward = 0;
        lastAction = getInitialAction();
        canContinue = true;
    }

    protected void resetEnvironment() {
        episodeStepNumber = 0;
        Map<String, Object> channelsData = environment.reset();
        this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
    }

    protected void resetPolicy() {
        policy.reset();
    }

    protected ACTION getInitialAction() {
        return environment.getSchema().getActionSchema().getNoOp();
    }

    protected void performStep() {
        onBeforeStep();

        ACTION action = decideAction(observation);

        canContinue = listeners.notifyBeforeStep(this, observation, action);
        if (!canContinue) {
            return;
        }

        StepResult stepResult = act(action);
        handleStepResult(stepResult);

        onAfterStep(stepResult);

        canContinue = listeners.notifyAfterStep(this, stepResult);
        if (!canContinue) {
            return;
        }

        incrementEpisodeStepNumber();
    }

    protected void incrementEpisodeStepNumber() {
        ++episodeStepNumber;
    }

    protected ACTION decideAction(Observation observation) {
        if (!observation.isSkipped()) {
            lastAction = policy.nextAction(observation);
        }

        return lastAction;
    }

    protected StepResult act(ACTION action) {
        return environment.step(action);
    }

    protected void handleStepResult(StepResult stepResult) {
        observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
        reward += computeReward(stepResult);
    }

    protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
        return transformProcess.transform(stepResult.getChannelsData(), episodeStepNumberOfObs, stepResult.isTerminal());
    }

    protected double computeReward(StepResult stepResult) {
        return stepResult.getReward();
    }

    protected void onAfterStep(StepResult stepResult) {
        // Do Nothing
    }

    protected void onBeforeStep() {
        // Do Nothing
    }

    public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
        return new Builder<>(environment, transformProcess, policy);
    }

    public static class Builder<ACTION> {
        private final Environment<ACTION> environment;
        private final TransformProcess transformProcess;
        private final IPolicy<ACTION> policy;
        private Integer maxEpisodeSteps = null; // Default: no max
        private String id;

        public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
            this.environment = environment;
            this.transformProcess = transformProcess;
            this.policy = policy;
        }

        public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
            Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
            this.maxEpisodeSteps = maxEpisodeSteps;

            return this;
        }

        public Builder<ACTION> id(String id) {
            this.id = id;
            return this;
        }

        public Agent build() {
            return new Agent(this);
        }
    }
}
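For orientation, a minimal usage sketch of the builder API above. The TransformProcess and policy construction are not part of this diff, so the transformProcess and policy variables here are assumptions:

    // Sketch only: transformProcess and policy are assumed to be built elsewhere.
    Environment<Integer> env = new CartpoleEnvironment();   // added later in this commit
    Agent<Integer> agent = Agent.builder(env, transformProcess, policy)
            .maxEpisodeSteps(200)   // optional; episodes are unbounded when omitted
            .id("cartpole-agent")   // hypothetical id
            .build();
    agent.run();                    // runs a single episode
    double totalReward = agent.getReward();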
org/deeplearning4j/rl4j/agent/listener/AgentListener.java (new file)
@@ -0,0 +1,23 @@
package org.deeplearning4j.rl4j.agent.listener;

import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;

public interface AgentListener<ACTION> {
    enum ListenerResponse {
        /**
         * Tell the learning process to continue calling the listeners and the training.
         */
        CONTINUE,

        /**
         * Tell the learning process to stop calling the listeners and terminate the training.
         */
        STOP,
    }

    AgentListener.ListenerResponse onBeforeEpisode(Agent agent);

    AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);

    AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
}
org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java (new file)
@@ -0,0 +1,50 @@
package org.deeplearning4j.rl4j.agent.listener;

import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;

import java.util.ArrayList;
import java.util.List;

public class AgentListenerList<ACTION> {
    protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();

    /**
     * Add a listener at the end of the list
     * @param listener The listener to be added
     */
    public void add(AgentListener<ACTION> listener) {
        listeners.add(listener);
    }

    public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
        for (AgentListener<ACTION> listener : listeners) {
            if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }

    public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
        for (AgentListener<ACTION> listener : listeners) {
            if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }

    public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
        for (AgentListener<ACTION> listener : listeners) {
            if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
                return false;
            }
        }

        return true;
    }
}
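As the notify methods show, the first STOP response short-circuits the remaining listeners and halts the run. For illustration, a hypothetical listener (not part of this commit) that ends the run once a total step budget is spent:

    import org.deeplearning4j.rl4j.agent.Agent;
    import org.deeplearning4j.rl4j.agent.listener.AgentListener;
    import org.deeplearning4j.rl4j.environment.StepResult;
    import org.deeplearning4j.rl4j.observation.Observation;

    // Hypothetical example: terminates the run after a fixed number of steps.
    public class StepBudgetListener implements AgentListener<Integer> {
        private final int budget;
        private int stepCount;

        public StepBudgetListener(int budget) {
            this.budget = budget;
        }

        @Override
        public ListenerResponse onBeforeEpisode(Agent agent) {
            return ListenerResponse.CONTINUE;
        }

        @Override
        public ListenerResponse onBeforeStep(Agent agent, Observation observation, Integer action) {
            return ListenerResponse.CONTINUE;
        }

        @Override
        public ListenerResponse onAfterStep(Agent agent, StepResult stepResult) {
            // Returning STOP makes Agent.performStep() bail out before the next step.
            return ++stepCount >= budget ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
        }
    }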
org/deeplearning4j/rl4j/environment/ActionSchema.java (new file)
@@ -0,0 +1,9 @@
package org.deeplearning4j.rl4j.environment;

import lombok.Value;

@Value
public class ActionSchema<ACTION> {
    private ACTION noOp;
    //FIXME ACTION randomAction();
}
org/deeplearning4j/rl4j/environment/Environment.java (new file)
@@ -0,0 +1,11 @@
package org.deeplearning4j.rl4j.environment;

import java.util.Map;

public interface Environment<ACTION> {
    Schema<ACTION> getSchema();
    Map<String, Object> reset();
    StepResult step(ACTION action);
    boolean isEpisodeFinished();
    void close();
}
org/deeplearning4j/rl4j/environment/Schema.java (new file)
@@ -0,0 +1,8 @@
package org.deeplearning4j.rl4j.environment;

import lombok.Value;

@Value
public class Schema<ACTION> {
    private ActionSchema<ACTION> actionSchema;
}
org/deeplearning4j/rl4j/environment/StepResult.java (new file)
@@ -0,0 +1,12 @@
package org.deeplearning4j.rl4j.environment;

import lombok.Value;

import java.util.Map;

@Value
public class StepResult {
    private Map<String, Object> channelsData;
    private double reward;
    private boolean terminal;
}
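Reading the three environment types together: reset() and step() exchange raw state as named channels, and a TransformProcess turns those channels into Observations. A sketch of one pass through that contract (env, transformProcess and action are assumed set up; the "data" channel name matches what CartpoleEnvironment below emits):

    // Assumes env, transformProcess and action exist.
    Map<String, Object> channels = env.reset();      // e.g. {"data": double[4]}
    Observation first = transformProcess.transform(channels, 0, false);

    StepResult result = env.step(action);
    Observation next = transformProcess.transform(
            result.getChannelsData(), 1, result.isTerminal());
    double reward = result.getReward();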
ILearning.java
@@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
  */
 public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
 
-    IPolicy<O, A> getPolicy();
+    IPolicy<A> getPolicy();
 
     void train();
 

AsyncThread.java
@@ -221,7 +221,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
     protected abstract IAsyncLearningConfiguration getConf();
 
-    protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net);
+    protected abstract IPolicy<ACTION> getPolicy(NN net);
 
     protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
 

AsyncThreadDiscrete.java
@@ -97,7 +97,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
         current.copy(getAsyncGlobal().getTarget());
 
         Observation obs = sObs;
-        IPolicy<O, Integer> policy = getPolicy(current);
+        IPolicy<Integer> policy = getPolicy(current);
 
         Integer action = getMdp().getActionSpace().noOp();
 

A3CThreadDiscrete.java
@@ -65,7 +65,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
     }
 
     @Override
-    protected Policy<O, Integer> getPolicy(IActorCritic net) {
+    protected Policy<Integer> getPolicy(IActorCritic net) {
         return new ACPolicy(net, rnd);
     }
 

AsyncNStepQLearningDiscrete.java
@@ -62,7 +62,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
         return asyncGlobal.getTarget();
     }
 
-    public IPolicy<O, Integer> getPolicy() {
+    public IPolicy<Integer> getPolicy() {
         return new DQNPolicy<O>(getNeuralNet());
     }
 

AsyncNStepQLearningThreadDiscrete.java
@@ -64,7 +64,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
         setUpdateAlgorithm(buildUpdateAlgorithm());
     }
 
-    public Policy<O, Integer> getPolicy(IDQN nn) {
+    public Policy<Integer> getPolicy(IDQN nn) {
         return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
                 rnd, conf.getMinEpsilon(), this);
     }
org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java (new file)
@@ -0,0 +1,129 @@
package org.deeplearning4j.rl4j.mdp;

import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;

public class CartpoleEnvironment implements Environment<Integer> {
    private static final int NUM_ACTIONS = 2;
    private static final int ACTION_LEFT = 0;
    private static final int ACTION_RIGHT = 1;

    private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));

    public enum KinematicsIntegrators { Euler, SemiImplicitEuler };

    private static final double gravity = 9.8;
    private static final double massCart = 1.0;
    private static final double massPole = 0.1;
    private static final double totalMass = massPole + massCart;
    private static final double length = 0.5; // actually half the pole's length
    private static final double polemassLength = massPole * length;
    private static final double forceMag = 10.0;
    private static final double tau = 0.02; // seconds between state updates

    // Angle at which to fail the episode
    private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
    private static final double xThreshold = 2.4;

    private final Random rnd;

    @Getter @Setter
    private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;

    @Getter
    private boolean episodeFinished = false;

    private double x;
    private double xDot;
    private double theta;
    private double thetaDot;
    private Integer stepsBeyondDone;

    public CartpoleEnvironment() {
        rnd = new Random();
    }

    public CartpoleEnvironment(int seed) {
        rnd = new Random(seed);
    }

    @Override
    public Schema<Integer> getSchema() {
        return schema;
    }

    @Override
    public Map<String, Object> reset() {
        x = 0.1 * rnd.nextDouble() - 0.05;
        xDot = 0.1 * rnd.nextDouble() - 0.05;
        theta = 0.1 * rnd.nextDouble() - 0.05;
        thetaDot = 0.1 * rnd.nextDouble() - 0.05;
        stepsBeyondDone = null;
        episodeFinished = false;

        return new HashMap<String, Object>() {{
            put("data", new double[]{x, xDot, theta, thetaDot});
        }};
    }

    @Override
    public StepResult step(Integer action) {
        double force = action == ACTION_RIGHT ? forceMag : -forceMag;
        double cosTheta = Math.cos(theta);
        double sinTheta = Math.sin(theta);
        double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
        double thetaAcc = (gravity * sinTheta - cosTheta * temp) / (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
        double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;

        switch (kinematicsIntegrator) {
            case Euler:
                x += tau * xDot;
                xDot += tau * xAcc;
                theta += tau * thetaDot;
                thetaDot += tau * thetaAcc;
                break;

            case SemiImplicitEuler:
                xDot += tau * xAcc;
                x += tau * xDot;
                thetaDot += tau * thetaAcc;
                theta += tau * thetaDot;
                break;
        }

        episodeFinished |= x < -xThreshold || x > xThreshold
                || theta < -thetaThresholdRadians || theta > thetaThresholdRadians;

        double reward;
        if (!episodeFinished) {
            reward = 1.0;
        }
        else if (stepsBeyondDone == null) {
            stepsBeyondDone = 0;
            reward = 1.0;
        }
        else {
            ++stepsBeyondDone;
            reward = 0;
        }

        Map<String, Object> channelsData = new HashMap<String, Object>() {{
            put("data", new double[]{x, xDot, theta, thetaDot});
        }};
        return new StepResult(channelsData, reward, episodeFinished);
    }

    @Override
    public void close() {
        // Do nothing
    }
}
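The acceleration lines in step() match the classic frictionless cart-pole dynamics of Barto, Sutton and Anderson (1983), the same formulation used by OpenAI Gym's CartPole. Written out, with F the applied force, m_c and m_p the cart and pole masses, and l the half-pole length:

    \ddot{\theta} = \frac{g\sin\theta - \cos\theta \, \frac{F + m_p l \dot{\theta}^2 \sin\theta}{m_c + m_p}}
                         {l \left( \frac{4}{3} - \frac{m_p \cos^2\theta}{m_c + m_p} \right)}
    \qquad
    \ddot{x} = \frac{F + m_p l \left( \dot{\theta}^2 \sin\theta - \ddot{\theta}\cos\theta \right)}{m_c + m_p}

The episode fails once |x| > 2.4 or |theta| exceeds 12 degrees, exactly as the thresholds above encode.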
ACPolicy.java
@@ -35,7 +35,7 @@ import java.io.IOException;
  * the softmax output of the actor critic, but objects constructed
  * with a {@link Random} argument of null return the max only.
  */
-public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
+public class ACPolicy<O extends Encodable> extends Policy<Integer> {
 
     final private IActorCritic actorCritic;
     Random rnd;
 

BoltzmannQ.java
@@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
  * Boltzmann exploration is a stochastic policy wrt to the
  * exponential Q-values as evaluated by the dqn model.
  */
-public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
+public class BoltzmannQ<O extends Encodable> extends Policy<Integer> {
 
     final private IDQN dqn;
     final private Random rnd;
 

DQNPolicy.java
@@ -32,8 +32,10 @@ import java.io.IOException;
  * DQN policy returns the action with the maximum Q-value as evaluated
  * by the dqn model
  */
+
+// FIXME: Should we rename this "GreedyPolicy"?
 @AllArgsConstructor
-public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {
+public class DQNPolicy<O> extends Policy<Integer> {
 
     final private IDQN dqn;
 

EpsGreedy.java
@@ -41,9 +41,9 @@ import org.nd4j.linalg.api.rng.Random;
  */
 @AllArgsConstructor
 @Slf4j
-public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
+public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
 
-    final private Policy<O, A> policy;
+    final private Policy<A> policy;
     final private MDP<O, A, AS> mdp;
     final private int updateStart;
     final private int epsilonNbStep;
IPolicy.java
@@ -7,8 +7,14 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-public interface IPolicy<O extends Encodable, A> {
-    <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
-    A nextAction(INDArray input);
-    A nextAction(Observation observation);
+public interface IPolicy<ACTION> {
+    @Deprecated
+    <O extends Encodable, AS extends ActionSpace<ACTION>> double play(MDP<O, ACTION, AS> mdp, IHistoryProcessor hp);
+
+    @Deprecated
+    ACTION nextAction(INDArray input);
+
+    ACTION nextAction(Observation observation);
+
+    void reset();
 }
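The observation type parameter thus moves off the interface and onto the one method that still needs it, so new code binds only the action type. A sketch of the effect at a call site (the dqn and observation variables are assumptions):

    // Before this commit a policy was IPolicy<O, A>, coupling it to an
    // observation type; now the O parameter only appears on the deprecated,
    // MDP-based play(...) overload.
    IPolicy<Integer> policy = new DQNPolicy<>(dqn);    // dqn assumed available
    Integer action = policy.nextAction(observation);   // preferred Observation path
    policy.reset();                                    // new per-episode hook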
Policy.java
@@ -34,22 +34,22 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
  *
  * A Policy's responsibility is to choose the next action given a state
  */
-public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
+public abstract class Policy<A> implements IPolicy<A> {
 
     public abstract NeuralNet getNeuralNet();
 
     public abstract A nextAction(Observation obs);
 
-    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
+    public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
         return play(mdp, (IHistoryProcessor)null);
     }
 
-    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) {
+    public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) {
         return play(mdp, new HistoryProcessor(conf));
     }
 
     @Override
-    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
+    public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
         resetNetworks();
 
         LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);
@@ -84,8 +84,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
     protected void resetNetworks() {
         getNeuralNet().reset();
     }
 
+    public void reset() {
+        resetNetworks();
+    }
+
-    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
+    protected <O extends Encodable, AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
 
         double reward = 0;
org/deeplearning4j/rl4j/agent/AgentTest.java (new file)
@@ -0,0 +1,483 @@
package org.deeplearning4j.rl4j.agent;

import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.junit.Rule;
import org.junit.Test;
import static org.junit.Assert.*;

import org.mockito.*;
import org.mockito.junit.*;
import org.nd4j.linalg.factory.Nd4j;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

public class AgentTest {

    @Mock Environment environmentMock;
    @Mock TransformProcess transformProcessMock;
    @Mock IPolicy policyMock;
    @Mock AgentListener listenerMock;

    @Rule
    public MockitoRule mockitoRule = MockitoJUnit.rule();

    @Test
    public void when_buildingWithNullEnvironment_expect_exception() {
        try {
            Agent.builder(null, null, null).build();
            fail("NullPointerException should have been thrown");
        } catch (NullPointerException exception) {
            String expectedMessage = "environment is marked non-null but is null";
            String actualMessage = exception.getMessage();

            assertTrue(actualMessage.contains(expectedMessage));
        }
    }

    @Test
    public void when_buildingWithNullTransformProcess_expect_exception() {
        try {
            Agent.builder(environmentMock, null, null).build();
            fail("NullPointerException should have been thrown");
        } catch (NullPointerException exception) {
            String expectedMessage = "transformProcess is marked non-null but is null";
            String actualMessage = exception.getMessage();

            assertTrue(actualMessage.contains(expectedMessage));
        }
    }

    @Test
    public void when_buildingWithNullPolicy_expect_exception() {
        try {
            Agent.builder(environmentMock, transformProcessMock, null).build();
            fail("NullPointerException should have been thrown");
        } catch (NullPointerException exception) {
            String expectedMessage = "policy is marked non-null but is null";
            String actualMessage = exception.getMessage();

            assertTrue(actualMessage.contains(expectedMessage));
        }
    }

    @Test
    public void when_buildingWithInvalidMaxSteps_expect_exception() {
        try {
            Agent.builder(environmentMock, transformProcessMock, policyMock)
                    .maxEpisodeSteps(0)
                    .build();
            fail("IllegalArgumentException should have been thrown");
        } catch (IllegalArgumentException exception) {
            String expectedMessage = "maxEpisodeSteps must be greater than 0, got [0]";
            String actualMessage = exception.getMessage();

            assertTrue(actualMessage.contains(expectedMessage));
        }
    }

    @Test
    public void when_buildingWithId_expect_idSetInAgent() {
        // Arrange
        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .id("TestAgent")
                .build();

        // Assert
        assertEquals("TestAgent", sut.getId());
    }

    @Test
    public void when_runIsCalled_expect_agentIsReset() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
        when(policyMock.nextAction(any(Observation.class))).thenReturn(1);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .build();

        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), anyInt())).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        // Act
        sut.run();

        // Assert
        assertEquals(0, sut.getEpisodeStepNumber());
        verify(transformProcessMock).transform(envResetResult, 0, false);
        verify(policyMock, times(1)).reset();
        assertEquals(0.0, sut.getReward(), 0.00001);
        verify(environmentMock, times(1)).reset();
    }

    @Test
    public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
        when(environmentMock.isEpisodeFinished()).thenReturn(true);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onBeforeEpisode();
        verify(spy, times(1)).onAfterEpisode();
    }

    @Test
    public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();

        when(listenerMock.onBeforeEpisode(any(Agent.class))).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onBeforeEpisode();
        verify(spy, never()).performStep();
        verify(spy, never()).onAfterStep(any(StepResult.class));
        verify(spy, never()).onAfterEpisode();
    }

    @Test
    public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .build();

        final Agent spy = Mockito.spy(sut);

        doAnswer(invocation -> {
            ((Agent)invocation.getMock()).incrementEpisodeStepNumber();
            return null;
        }).when(spy).performStep();
        when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5);

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onBeforeEpisode();
        verify(spy, times(5)).performStep();
        verify(spy, times(1)).onAfterEpisode();
    }

    @Test
    public void when_maxStepsIsReachedBeforeEpisodeEnds_expect_runTerminated() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(3)
                .build();

        final Agent spy = Mockito.spy(sut);

        doAnswer(invocation -> {
            ((Agent)invocation.getMock()).incrementEpisodeStepNumber();
            return null;
        }).when(spy).performStep();

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onBeforeEpisode();
        verify(spy, times(3)).performStep();
        verify(spy, times(1)).onAfterEpisode();
    }

    @Test
    public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .build();

        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
    }

    @Test
    public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .build();

        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
    }

    @Test
    public void when_observationsIsSkipped_expect_performLastAction() {
        // Arrange
        Map<String, Object> envResetResult = new HashMap<>();
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(envResetResult);
        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
        when(environmentMock.getSchema()).thenReturn(schema);

        when(policyMock.nextAction(any(Observation.class)))
                .thenAnswer(invocation -> (int)((Observation)invocation.getArgument(0)).getData().getDouble(0));

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(3)
                .build();

        Agent spy = Mockito.spy(sut);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
                .thenAnswer(invocation -> {
                    int stepNumber = (int)invocation.getArgument(1);
                    return stepNumber % 2 == 1 ? Observation.SkippedObservation
                            : new Observation(Nd4j.create(new double[] { stepNumber }));
                });

        sut.addListener(listenerMock);

        // Act
        spy.run();

        // Assert
        verify(policyMock, times(2)).nextAction(any(Observation.class));

        ArgumentCaptor<Agent> agentCaptor = ArgumentCaptor.forClass(Agent.class);
        ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
        ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
        verify(listenerMock, times(3)).onBeforeStep(agentCaptor.capture(), observationCaptor.capture(), actionCaptor.capture());
        List<Integer> capturedActions = actionCaptor.getAllValues();
        assertEquals(0, (int)capturedActions.get(0));
        assertEquals(0, (int)capturedActions.get(1));
        assertEquals(2, (int)capturedActions.get(2));
    }

    @Test
    public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();

        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onBeforeEpisode();
        verify(spy, times(1)).onBeforeStep();
        verify(spy, never()).act(any());
        verify(spy, never()).onAfterStep(any(StepResult.class));
        verify(spy, never()).onAfterEpisode();
    }

    @Test
    public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);
        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(1)
                .build();

        // Act
        sut.run();

        // Assert
        verify(environmentMock, times(1)).step(123);
    }

    @Test
    public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);
        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(1)
                .build();

        // Act
        sut.run();

        // Assert
        assertEquals(123.0, sut.getObservation().getData().getDouble(0), 0.00001);
        assertEquals(234.0, sut.getReward(), 0.00001);
    }

    @Test
    public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);
        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(1)
                .build();
        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy).onAfterStep(stepResult);
    }

    @Test
    public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);
        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(1)
                .build();
        when(listenerMock.onAfterStep(any(Agent.class), any(StepResult.class))).thenReturn(AgentListener.ListenerResponse.STOP);
        sut.addListener(listenerMock);

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy, never()).onAfterEpisode();
    }

    @Test
    public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
        // Arrange
        Schema schema = new Schema(new ActionSchema<>(-1));
        when(environmentMock.reset()).thenReturn(new HashMap<>());
        when(environmentMock.getSchema()).thenReturn(schema);
        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);

        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));

        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);

        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
                .maxEpisodeSteps(1)
                .build();

        Agent spy = Mockito.spy(sut);

        // Act
        spy.run();

        // Assert
        verify(spy, times(1)).onAfterEpisode();
    }
}
AsyncThreadDiscreteTest.java
@@ -62,7 +62,7 @@ public class AsyncThreadDiscreteTest {
     IAsyncGlobal<NeuralNet> mockAsyncGlobal;
 
     @Mock
-    Policy<Encodable, Integer> mockGlobalCurrentPolicy;
+    Policy<Integer> mockGlobalCurrentPolicy;
 
     @Mock
     NeuralNet mockGlobalTargetNetwork;
PolicyTest.java
@@ -30,13 +30,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.support.MockDQN;
-import org.deeplearning4j.rl4j.support.MockEncodable;
-import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
-import org.deeplearning4j.rl4j.support.MockMDP;
-import org.deeplearning4j.rl4j.support.MockNeuralNet;
-import org.deeplearning4j.rl4j.support.MockObservationSpace;
-import org.deeplearning4j.rl4j.support.MockRandom;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.support.*;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -227,7 +222,7 @@ public class PolicyTest {
         assertEquals(0, dqn.outputParams.size());
     }
 
-    public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
+    public static class MockRefacPolicy extends Policy<Integer> {
 
         private NeuralNet neuralNet;
         private final int[] shape;
@@ -257,7 +252,7 @@
         }
 
         @Override
-        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
+        protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
             mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
             return super.refacInitMdp(mdpWrapper, hp);
         }
MockPolicy.java
@@ -5,18 +5,19 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.ArrayList;
 import java.util.List;
 
-public class MockPolicy implements IPolicy<MockEncodable, Integer> {
+public class MockPolicy implements IPolicy<Integer> {
 
     public int playCallCount = 0;
     public List<INDArray> actionInputs = new ArrayList<INDArray>();
 
     @Override
-    public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
+    public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }
@@ -31,4 +32,9 @@ public class MockPolicy implements IPolicy<MockEncodable, Integer> {
     public Integer nextAction(Observation observation) {
         return nextAction(observation.getData());
     }
+
+    @Override
+    public void reset() {
+
+    }
 }