RL4J: Add Agent and Environment (#358)

* Added Agent and Environment

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added headers

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fix compilation errors

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>
master
Alexandre Boulanger 2020-04-21 20:13:08 -04:00 committed by GitHub
parent a10fd4524a
commit 550e84ef43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 979 additions and 32 deletions

View File

@ -0,0 +1,210 @@
package org.deeplearning4j.rl4j.agent;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.agent.listener.AgentListenerList;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.base.Preconditions;
import java.util.Map;
public class Agent<ACTION> {
@Getter
private final String id;
@Getter
private final Environment<ACTION> environment;
@Getter
private final IPolicy<ACTION> policy;
private final TransformProcess transformProcess;
protected final AgentListenerList<ACTION> listeners;
private final Integer maxEpisodeSteps;
@Getter(AccessLevel.PROTECTED)
private Observation observation;
@Getter(AccessLevel.PROTECTED)
private ACTION lastAction;
@Getter
private int episodeStepNumber;
@Getter
private double reward;
protected boolean canContinue;
private Agent(Builder<ACTION> builder) {
this.environment = builder.environment;
this.transformProcess = builder.transformProcess;
this.policy = builder.policy;
this.maxEpisodeSteps = builder.maxEpisodeSteps;
this.id = builder.id;
listeners = buildListenerList();
}
protected AgentListenerList<ACTION> buildListenerList() {
return new AgentListenerList<ACTION>();
}
public void addListener(AgentListener listener) {
listeners.add(listener);
}
public void run() {
runEpisode();
}
protected void onBeforeEpisode() {
// Do Nothing
}
protected void onAfterEpisode() {
// Do Nothing
}
protected void runEpisode() {
reset();
onBeforeEpisode();
canContinue = listeners.notifyBeforeEpisode(this);
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
performStep();
}
if(!canContinue) {
return;
}
onAfterEpisode();
}
protected void reset() {
resetEnvironment();
resetPolicy();
reward = 0;
lastAction = getInitialAction();
canContinue = true;
}
protected void resetEnvironment() {
episodeStepNumber = 0;
Map<String, Object> channelsData = environment.reset();
this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
}
protected void resetPolicy() {
policy.reset();
}
protected ACTION getInitialAction() {
return environment.getSchema().getActionSchema().getNoOp();
}
protected void performStep() {
onBeforeStep();
ACTION action = decideAction(observation);
canContinue = listeners.notifyBeforeStep(this, observation, action);
if(!canContinue) {
return;
}
StepResult stepResult = act(action);
handleStepResult(stepResult);
onAfterStep(stepResult);
canContinue = listeners.notifyAfterStep(this, stepResult);
if(!canContinue) {
return;
}
incrementEpisodeStepNumber();
}
protected void incrementEpisodeStepNumber() {
++episodeStepNumber;
}
protected ACTION decideAction(Observation observation) {
if (!observation.isSkipped()) {
lastAction = policy.nextAction(observation);
}
return lastAction;
}
protected StepResult act(ACTION action) {
return environment.step(action);
}
protected void handleStepResult(StepResult stepResult) {
observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
reward +=computeReward(stepResult);
}
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
return transformProcess.transform(stepResult.getChannelsData(), episodeStepNumberOfObs, stepResult.isTerminal());
}
protected double computeReward(StepResult stepResult) {
return stepResult.getReward();
}
protected void onAfterStep(StepResult stepResult) {
// Do Nothing
}
protected void onBeforeStep() {
// Do Nothing
}
public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
return new Builder<>(environment, transformProcess, policy);
}
public static class Builder<ACTION> {
private final Environment<ACTION> environment;
private final TransformProcess transformProcess;
private final IPolicy<ACTION> policy;
private Integer maxEpisodeSteps = null; // Default, no max
private String id;
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
this.environment = environment;
this.transformProcess = transformProcess;
this.policy = policy;
}
public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
this.maxEpisodeSteps = maxEpisodeSteps;
return this;
}
public Builder<ACTION> id(String id) {
this.id = id;
return this;
}
public Agent build() {
return new Agent(this);
}
}
}

View File

@ -0,0 +1,23 @@
package org.deeplearning4j.rl4j.agent.listener;
import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
public interface AgentListener<ACTION> {
enum ListenerResponse {
/**
* Tell the learning process to continue calling the listeners and the training.
*/
CONTINUE,
/**
* Tell the learning process to stop calling the listeners and terminate the training.
*/
STOP,
}
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
}

View File

@ -0,0 +1,50 @@
package org.deeplearning4j.rl4j.agent.listener;
import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
public class AgentListenerList<ACTION> {
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
/**
* Add a listener at the end of the list
* @param listener The listener to be added
*/
public void add(AgentListener<ACTION> listener) {
listeners.add(listener);
}
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
return false;
}
}
return true;
}
}

View File

@ -0,0 +1,9 @@
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
@Value
public class ActionSchema<ACTION> {
private ACTION noOp;
//FIXME ACTION randomAction();
}

View File

@ -0,0 +1,11 @@
package org.deeplearning4j.rl4j.environment;
import java.util.Map;
public interface Environment<ACTION> {
Schema<ACTION> getSchema();
Map<String, Object> reset();
StepResult step(ACTION action);
boolean isEpisodeFinished();
void close();
}

View File

@ -0,0 +1,8 @@
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
@Value
public class Schema<ACTION> {
private ActionSchema<ACTION> actionSchema;
}

View File

@ -0,0 +1,12 @@
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
import java.util.Map;
@Value
public class StepResult {
private Map<String, Object> channelsData;
private double reward;
private boolean terminal;
}

View File

@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
*/ */
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> { public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
IPolicy<O, A> getPolicy(); IPolicy<A> getPolicy();
void train(); void train();

View File

@ -221,7 +221,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
protected abstract IAsyncLearningConfiguration getConf(); protected abstract IAsyncLearningConfiguration getConf();
protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net); protected abstract IPolicy<ACTION> getPolicy(NN net);
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);

View File

@ -97,7 +97,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
current.copy(getAsyncGlobal().getTarget()); current.copy(getAsyncGlobal().getTarget());
Observation obs = sObs; Observation obs = sObs;
IPolicy<O, Integer> policy = getPolicy(current); IPolicy<Integer> policy = getPolicy(current);
Integer action = getMdp().getActionSpace().noOp(); Integer action = getMdp().getActionSpace().noOp();

View File

@ -65,7 +65,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
} }
@Override @Override
protected Policy<O, Integer> getPolicy(IActorCritic net) { protected Policy<Integer> getPolicy(IActorCritic net) {
return new ACPolicy(net, rnd); return new ACPolicy(net, rnd);
} }

View File

@ -62,7 +62,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
return asyncGlobal.getTarget(); return asyncGlobal.getTarget();
} }
public IPolicy<O, Integer> getPolicy() { public IPolicy<Integer> getPolicy() {
return new DQNPolicy<O>(getNeuralNet()); return new DQNPolicy<O>(getNeuralNet());
} }

View File

@ -64,7 +64,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
setUpdateAlgorithm(buildUpdateAlgorithm()); setUpdateAlgorithm(buildUpdateAlgorithm());
} }
public Policy<O, Integer> getPolicy(IDQN nn) { public Policy<Integer> getPolicy(IDQN nn) {
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(), return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
rnd, conf.getMinEpsilon(), this); rnd, conf.getMinEpsilon(), this);
} }

View File

@ -0,0 +1,129 @@
package org.deeplearning4j.rl4j.mdp;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
public class CartpoleEnvironment implements Environment<Integer> {
private static final int NUM_ACTIONS = 2;
private static final int ACTION_LEFT = 0;
private static final int ACTION_RIGHT = 1;
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
private static final double gravity = 9.8;
private static final double massCart = 1.0;
private static final double massPole = 0.1;
private static final double totalMass = massPole + massCart;
private static final double length = 0.5; // actually half the pole's length
private static final double polemassLength = massPole * length;
private static final double forceMag = 10.0;
private static final double tau = 0.02; // seconds between state updates
// Angle at which to fail the episode
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
private static final double xThreshold = 2.4;
private final Random rnd;
@Getter @Setter
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
@Getter
private boolean episodeFinished = false;
private double x;
private double xDot;
private double theta;
private double thetaDot;
private Integer stepsBeyondDone;
public CartpoleEnvironment() {
rnd = new Random();
}
public CartpoleEnvironment(int seed) {
rnd = new Random(seed);
}
@Override
public Schema<Integer> getSchema() {
return schema;
}
@Override
public Map<String, Object> reset() {
x = 0.1 * rnd.nextDouble() - 0.05;
xDot = 0.1 * rnd.nextDouble() - 0.05;
theta = 0.1 * rnd.nextDouble() - 0.05;
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
stepsBeyondDone = null;
episodeFinished = false;
return new HashMap<String, Object>() {{
put("data", new double[]{x, xDot, theta, thetaDot});
}};
}
@Override
public StepResult step(Integer action) {
double force = action == ACTION_RIGHT ? forceMag : -forceMag;
double cosTheta = Math.cos(theta);
double sinTheta = Math.sin(theta);
double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
double thetaAcc = (gravity * sinTheta - cosTheta* temp) / (length * (4.0/3.0 - massPole * cosTheta * cosTheta / totalMass));
double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;
switch(kinematicsIntegrator) {
case Euler:
x += tau * xDot;
xDot += tau * xAcc;
theta += tau * thetaDot;
thetaDot += tau * thetaAcc;
break;
case SemiImplicitEuler:
xDot += tau * xAcc;
x += tau * xDot;
thetaDot += tau * thetaAcc;
theta += tau * thetaDot;
break;
}
episodeFinished |= x < -xThreshold || x > xThreshold
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
double reward;
if(!episodeFinished) {
reward = 1.0;
}
else if(stepsBeyondDone == null) {
stepsBeyondDone = 0;
reward = 1.0;
}
else {
++stepsBeyondDone;
reward = 0;
}
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("data", new double[]{x, xDot, theta, thetaDot});
}};
return new StepResult(channelsData, reward, episodeFinished);
}
@Override
public void close() {
// Do nothing
}
}

View File

@ -35,7 +35,7 @@ import java.io.IOException;
* the softmax output of the actor critic, but objects constructed * the softmax output of the actor critic, but objects constructed
* with a {@link Random} argument of null return the max only. * with a {@link Random} argument of null return the max only.
*/ */
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> { public class ACPolicy<O extends Encodable> extends Policy<Integer> {
final private IActorCritic actorCritic; final private IActorCritic actorCritic;
Random rnd; Random rnd;

View File

@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
* Boltzmann exploration is a stochastic policy wrt to the * Boltzmann exploration is a stochastic policy wrt to the
* exponential Q-values as evaluated by the dqn model. * exponential Q-values as evaluated by the dqn model.
*/ */
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> { public class BoltzmannQ<O extends Encodable> extends Policy<Integer> {
final private IDQN dqn; final private IDQN dqn;
final private Random rnd; final private Random rnd;

View File

@ -32,8 +32,10 @@ import java.io.IOException;
* DQN policy returns the action with the maximum Q-value as evaluated * DQN policy returns the action with the maximum Q-value as evaluated
* by the dqn model * by the dqn model
*/ */
// FIXME: Should we rename this "GreedyPolicy"?
@AllArgsConstructor @AllArgsConstructor
public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> { public class DQNPolicy<O> extends Policy<Integer> {
final private IDQN dqn; final private IDQN dqn;

View File

@ -41,9 +41,9 @@ import org.nd4j.linalg.api.rng.Random;
*/ */
@AllArgsConstructor @AllArgsConstructor
@Slf4j @Slf4j
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> { public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
final private Policy<O, A> policy; final private Policy<A> policy;
final private MDP<O, A, AS> mdp; final private MDP<O, A, AS> mdp;
final private int updateStart; final private int updateStart;
final private int epsilonNbStep; final private int epsilonNbStep;

View File

@ -7,8 +7,14 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
public interface IPolicy<O extends Encodable, A> { public interface IPolicy<ACTION> {
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp); @Deprecated
A nextAction(INDArray input); <O extends Encodable, AS extends ActionSpace<ACTION>> double play(MDP<O, ACTION, AS> mdp, IHistoryProcessor hp);
A nextAction(Observation observation);
@Deprecated
ACTION nextAction(INDArray input);
ACTION nextAction(Observation observation);
void reset();
} }

View File

@ -34,22 +34,22 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
* *
* A Policy responsability is to choose the next action given a state * A Policy responsability is to choose the next action given a state
*/ */
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> { public abstract class Policy<A> implements IPolicy<A> {
public abstract NeuralNet getNeuralNet(); public abstract NeuralNet getNeuralNet();
public abstract A nextAction(Observation obs); public abstract A nextAction(Observation obs);
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) { public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
return play(mdp, (IHistoryProcessor)null); return play(mdp, (IHistoryProcessor)null);
} }
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) { public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) {
return play(mdp, new HistoryProcessor(conf)); return play(mdp, new HistoryProcessor(conf));
} }
@Override @Override
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) { public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
resetNetworks(); resetNetworks();
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp); LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);
@ -84,8 +84,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
protected void resetNetworks() { protected void resetNetworks() {
getNeuralNet().reset(); getNeuralNet().reset();
} }
public void reset() {
resetNetworks();
}
protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) { protected <O extends Encodable, AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
double reward = 0; double reward = 0;

View File

@ -0,0 +1,483 @@
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.junit.Rule;
import org.junit.Test;
import static org.junit.Assert.*;
import org.mockito.*;
import org.mockito.junit.*;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
public class AgentTest {
@Mock Environment environmentMock;
@Mock TransformProcess transformProcessMock;
@Mock IPolicy policyMock;
@Mock AgentListener listenerMock;
@Rule
public MockitoRule mockitoRule = MockitoJUnit.rule();
@Test
public void when_buildingWithNullEnvironment_expect_exception() {
try {
Agent.builder(null, null, null).build();
fail("NullPointerException should have been thrown");
} catch (NullPointerException exception) {
String expectedMessage = "environment is marked non-null but is null";
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains(expectedMessage));
}
}
@Test
public void when_buildingWithNullTransformProcess_expect_exception() {
try {
Agent.builder(environmentMock, null, null).build();
fail("NullPointerException should have been thrown");
} catch (NullPointerException exception) {
String expectedMessage = "transformProcess is marked non-null but is null";
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains(expectedMessage));
}
}
@Test
public void when_buildingWithNullPolicy_expect_exception() {
try {
Agent.builder(environmentMock, transformProcessMock, null).build();
fail("NullPointerException should have been thrown");
} catch (NullPointerException exception) {
String expectedMessage = "policy is marked non-null but is null";
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains(expectedMessage));
}
}
@Test
public void when_buildingWithInvalidMaxSteps_expect_exception() {
try {
Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(0)
.build();
fail("IllegalArgumentException should have been thrown");
} catch (IllegalArgumentException exception) {
String expectedMessage = "maxEpisodeSteps must be greater than 0, got [0]";
String actualMessage = exception.getMessage();
assertTrue(actualMessage.contains(expectedMessage));
}
}
@Test
public void when_buildingWithId_expect_idSetInAgent() {
// Arrange
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.id("TestAgent")
.build();
// Assert
assertEquals("TestAgent", sut.getId());
}
@Test
public void when_runIsCalled_expect_agentIsReset() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(1);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.build();
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), anyInt())).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
// Act
sut.run();
// Assert
assertEquals(0, sut.getEpisodeStepNumber());
verify(transformProcessMock).transform(envResetResult, 0, false);
verify(policyMock, times(1)).reset();
assertEquals(0.0, sut.getReward(), 0.00001);
verify(environmentMock, times(1)).reset();
}
@Test
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(environmentMock.isEpisodeFinished()).thenReturn(true);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy, times(1)).onBeforeEpisode();
verify(spy, times(1)).onAfterEpisode();
}
@Test
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
when(listenerMock.onBeforeEpisode(any(Agent.class))).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy, times(1)).onBeforeEpisode();
verify(spy, never()).performStep();
verify(spy, never()).onAfterStep(any(StepResult.class));
verify(spy, never()).onAfterEpisode();
}
@Test
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.build();
final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
return null;
}).when(spy).performStep();
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
// Act
spy.run();
// Assert
verify(spy, times(1)).onBeforeEpisode();
verify(spy, times(5)).performStep();
verify(spy, times(1)).onAfterEpisode();
}
@Test
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(3)
.build();
final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
return null;
}).when(spy).performStep();
// Act
spy.run();
// Assert
verify(spy, times(1)).onBeforeEpisode();
verify(spy, times(3)).performStep();
verify(spy, times(1)).onAfterEpisode();
}
@Test
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.build();
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
}
@Test
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.build();
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
}
@Test
public void when_observationsIsSkipped_expect_performLastAction() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
when(environmentMock.getSchema()).thenReturn(schema);
when(policyMock.nextAction(any(Observation.class)))
.thenAnswer(invocation -> (int)((Observation)invocation.getArgument(0)).getData().getDouble(0));
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(3)
.build();
Agent spy = Mockito.spy(sut);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(invocation -> {
int stepNumber = (int)invocation.getArgument(1);
return stepNumber % 2 == 1 ? Observation.SkippedObservation
: new Observation(Nd4j.create(new double[] { stepNumber }));
});
sut.addListener(listenerMock);
// Act
spy.run();
// Assert
verify(policyMock, times(2)).nextAction(any(Observation.class));
ArgumentCaptor<Agent> agentCaptor = ArgumentCaptor.forClass(Agent.class);
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
verify(listenerMock, times(3)).onBeforeStep(agentCaptor.capture(), observationCaptor.capture(), actionCaptor.capture());
List<Integer> capturedActions = actionCaptor.getAllValues();
assertEquals(0, (int)capturedActions.get(0));
assertEquals(0, (int)capturedActions.get(1));
assertEquals(2, (int)capturedActions.get(2));
}
@Test
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy, times(1)).onBeforeEpisode();
verify(spy, times(1)).onBeforeStep();
verify(spy, never()).act(any());
verify(spy, never()).onAfterStep(any(StepResult.class));
verify(spy, never()).onAfterEpisode();
}
@Test
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(1)
.build();
// Act
sut.run();
// Assert
verify(environmentMock, times(1)).step(123);
}
@Test
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(1)
.build();
// Act
sut.run();
// Assert
assertEquals(123.0, sut.getObservation().getData().getDouble(0), 0.00001);
assertEquals(234.0, sut.getReward(), 0.00001);
}
@Test
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(1)
.build();
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy).onAfterStep(stepResult);
}
@Test
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(1)
.build();
when(listenerMock.onAfterStep(any(Agent.class), any(StepResult.class))).thenReturn(AgentListener.ListenerResponse.STOP);
sut.addListener(listenerMock);
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy, never()).onAfterEpisode();
}
@Test
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
.maxEpisodeSteps(1)
.build();
Agent spy = Mockito.spy(sut);
// Act
spy.run();
// Assert
verify(spy, times(1)).onAfterEpisode();
}
}

View File

@ -62,7 +62,7 @@ public class AsyncThreadDiscreteTest {
IAsyncGlobal<NeuralNet> mockAsyncGlobal; IAsyncGlobal<NeuralNet> mockAsyncGlobal;
@Mock @Mock
Policy<Encodable, Integer> mockGlobalCurrentPolicy; Policy<Integer> mockGlobalCurrentPolicy;
@Mock @Mock
NeuralNet mockGlobalTargetNetwork; NeuralNet mockGlobalTargetNetwork;

View File

@ -30,13 +30,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.support.MockDQN; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.MockEncodable; import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockRandom;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -227,7 +222,7 @@ public class PolicyTest {
assertEquals(0, dqn.outputParams.size()); assertEquals(0, dqn.outputParams.size());
} }
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> { public static class MockRefacPolicy extends Policy<Integer> {
private NeuralNet neuralNet; private NeuralNet neuralNet;
private final int[] shape; private final int[] shape;
@ -257,7 +252,7 @@ public class PolicyTest {
} }
@Override @Override
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) { protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
return super.refacInitMdp(mdpWrapper, hp); return super.refacInitMdp(mdpWrapper, hp);
} }

View File

@ -5,18 +5,19 @@ import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class MockPolicy implements IPolicy<MockEncodable, Integer> { public class MockPolicy implements IPolicy<Integer> {
public int playCallCount = 0; public int playCallCount = 0;
public List<INDArray> actionInputs = new ArrayList<INDArray>(); public List<INDArray> actionInputs = new ArrayList<INDArray>();
@Override @Override
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) { public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
++playCallCount; ++playCallCount;
return 0; return 0;
} }
@ -31,4 +32,9 @@ public class MockPolicy implements IPolicy<MockEncodable, Integer> {
public Integer nextAction(Observation observation) { public Integer nextAction(Observation observation) {
return nextAction(observation.getData()); return nextAction(observation.getData());
} }
@Override
public void reset() {
}
} }