RL4J: Add Agent and Environment (#358)
* Added Agent and Environment Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added headers Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fix compilation errors Signed-off-by: Samuel Audet <samuel.audet@gmail.com>master
parent
a10fd4524a
commit
550e84ef43
|
@ -0,0 +1,210 @@
|
|||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
|
||||
import org.deeplearning4j.rl4j.agent.listener.AgentListenerList;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class Agent<ACTION> {
|
||||
@Getter
|
||||
private final String id;
|
||||
|
||||
@Getter
|
||||
private final Environment<ACTION> environment;
|
||||
|
||||
@Getter
|
||||
private final IPolicy<ACTION> policy;
|
||||
|
||||
private final TransformProcess transformProcess;
|
||||
|
||||
protected final AgentListenerList<ACTION> listeners;
|
||||
|
||||
private final Integer maxEpisodeSteps;
|
||||
|
||||
@Getter(AccessLevel.PROTECTED)
|
||||
private Observation observation;
|
||||
|
||||
@Getter(AccessLevel.PROTECTED)
|
||||
private ACTION lastAction;
|
||||
|
||||
@Getter
|
||||
private int episodeStepNumber;
|
||||
|
||||
@Getter
|
||||
private double reward;
|
||||
|
||||
protected boolean canContinue;
|
||||
|
||||
private Agent(Builder<ACTION> builder) {
|
||||
this.environment = builder.environment;
|
||||
this.transformProcess = builder.transformProcess;
|
||||
this.policy = builder.policy;
|
||||
this.maxEpisodeSteps = builder.maxEpisodeSteps;
|
||||
this.id = builder.id;
|
||||
|
||||
listeners = buildListenerList();
|
||||
}
|
||||
|
||||
protected AgentListenerList<ACTION> buildListenerList() {
|
||||
return new AgentListenerList<ACTION>();
|
||||
}
|
||||
|
||||
public void addListener(AgentListener listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
public void run() {
|
||||
runEpisode();
|
||||
}
|
||||
|
||||
protected void onBeforeEpisode() {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
protected void onAfterEpisode() {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
protected void runEpisode() {
|
||||
reset();
|
||||
onBeforeEpisode();
|
||||
|
||||
canContinue = listeners.notifyBeforeEpisode(this);
|
||||
|
||||
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
|
||||
performStep();
|
||||
}
|
||||
|
||||
if(!canContinue) {
|
||||
return;
|
||||
}
|
||||
|
||||
onAfterEpisode();
|
||||
}
|
||||
|
||||
protected void reset() {
|
||||
resetEnvironment();
|
||||
resetPolicy();
|
||||
reward = 0;
|
||||
lastAction = getInitialAction();
|
||||
canContinue = true;
|
||||
}
|
||||
|
||||
protected void resetEnvironment() {
|
||||
episodeStepNumber = 0;
|
||||
Map<String, Object> channelsData = environment.reset();
|
||||
this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
|
||||
}
|
||||
|
||||
protected void resetPolicy() {
|
||||
policy.reset();
|
||||
}
|
||||
|
||||
protected ACTION getInitialAction() {
|
||||
return environment.getSchema().getActionSchema().getNoOp();
|
||||
}
|
||||
|
||||
protected void performStep() {
|
||||
|
||||
onBeforeStep();
|
||||
|
||||
ACTION action = decideAction(observation);
|
||||
|
||||
canContinue = listeners.notifyBeforeStep(this, observation, action);
|
||||
if(!canContinue) {
|
||||
return;
|
||||
}
|
||||
|
||||
StepResult stepResult = act(action);
|
||||
handleStepResult(stepResult);
|
||||
|
||||
onAfterStep(stepResult);
|
||||
|
||||
canContinue = listeners.notifyAfterStep(this, stepResult);
|
||||
if(!canContinue) {
|
||||
return;
|
||||
}
|
||||
|
||||
incrementEpisodeStepNumber();
|
||||
}
|
||||
|
||||
protected void incrementEpisodeStepNumber() {
|
||||
++episodeStepNumber;
|
||||
}
|
||||
|
||||
protected ACTION decideAction(Observation observation) {
|
||||
if (!observation.isSkipped()) {
|
||||
lastAction = policy.nextAction(observation);
|
||||
}
|
||||
|
||||
return lastAction;
|
||||
}
|
||||
|
||||
protected StepResult act(ACTION action) {
|
||||
return environment.step(action);
|
||||
}
|
||||
|
||||
protected void handleStepResult(StepResult stepResult) {
|
||||
observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
|
||||
reward +=computeReward(stepResult);
|
||||
}
|
||||
|
||||
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
|
||||
return transformProcess.transform(stepResult.getChannelsData(), episodeStepNumberOfObs, stepResult.isTerminal());
|
||||
}
|
||||
|
||||
protected double computeReward(StepResult stepResult) {
|
||||
return stepResult.getReward();
|
||||
}
|
||||
|
||||
protected void onAfterStep(StepResult stepResult) {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
protected void onBeforeStep() {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||
return new Builder<>(environment, transformProcess, policy);
|
||||
}
|
||||
|
||||
public static class Builder<ACTION> {
|
||||
private final Environment<ACTION> environment;
|
||||
private final TransformProcess transformProcess;
|
||||
private final IPolicy<ACTION> policy;
|
||||
private Integer maxEpisodeSteps = null; // Default, no max
|
||||
private String id;
|
||||
|
||||
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||
this.environment = environment;
|
||||
this.transformProcess = transformProcess;
|
||||
this.policy = policy;
|
||||
}
|
||||
|
||||
public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
|
||||
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
|
||||
this.maxEpisodeSteps = maxEpisodeSteps;
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<ACTION> id(String id) {
|
||||
this.id = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Agent build() {
|
||||
return new Agent(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package org.deeplearning4j.rl4j.agent.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.Agent;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
public interface AgentListener<ACTION> {
|
||||
enum ListenerResponse {
|
||||
/**
|
||||
* Tell the learning process to continue calling the listeners and the training.
|
||||
*/
|
||||
CONTINUE,
|
||||
|
||||
/**
|
||||
* Tell the learning process to stop calling the listeners and terminate the training.
|
||||
*/
|
||||
STOP,
|
||||
}
|
||||
|
||||
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
|
||||
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
|
||||
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package org.deeplearning4j.rl4j.agent.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.Agent;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class AgentListenerList<ACTION> {
|
||||
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* Add a listener at the end of the list
|
||||
* @param listener The listener to be added
|
||||
*/
|
||||
public void add(AgentListener<ACTION> listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
@Value
|
||||
public class ActionSchema<ACTION> {
|
||||
private ACTION noOp;
|
||||
//FIXME ACTION randomAction();
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface Environment<ACTION> {
|
||||
Schema<ACTION> getSchema();
|
||||
Map<String, Object> reset();
|
||||
StepResult step(ACTION action);
|
||||
boolean isEpisodeFinished();
|
||||
void close();
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
@Value
|
||||
public class Schema<ACTION> {
|
||||
private ActionSchema<ACTION> actionSchema;
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Value
|
||||
public class StepResult {
|
||||
private Map<String, Object> channelsData;
|
||||
private double reward;
|
||||
private boolean terminal;
|
||||
}
|
|
@ -30,7 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
*/
|
||||
public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {
|
||||
|
||||
IPolicy<O, A> getPolicy();
|
||||
IPolicy<A> getPolicy();
|
||||
|
||||
void train();
|
||||
|
||||
|
|
|
@ -221,7 +221,7 @@ public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_
|
|||
|
||||
protected abstract IAsyncLearningConfiguration getConf();
|
||||
|
||||
protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net);
|
||||
protected abstract IPolicy<ACTION> getPolicy(NN net);
|
||||
|
||||
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
|
||||
|
||||
|
|
|
@ -97,7 +97,7 @@ public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends Neural
|
|||
current.copy(getAsyncGlobal().getTarget());
|
||||
|
||||
Observation obs = sObs;
|
||||
IPolicy<O, Integer> policy = getPolicy(current);
|
||||
IPolicy<Integer> policy = getPolicy(current);
|
||||
|
||||
Integer action = getMdp().getActionSpace().noOp();
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
}
|
||||
|
||||
@Override
|
||||
protected Policy<O, Integer> getPolicy(IActorCritic net) {
|
||||
protected Policy<Integer> getPolicy(IActorCritic net) {
|
||||
return new ACPolicy(net, rnd);
|
||||
}
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
|||
return asyncGlobal.getTarget();
|
||||
}
|
||||
|
||||
public IPolicy<O, Integer> getPolicy() {
|
||||
public IPolicy<Integer> getPolicy() {
|
||||
return new DQNPolicy<O>(getNeuralNet());
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
setUpdateAlgorithm(buildUpdateAlgorithm());
|
||||
}
|
||||
|
||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||
public Policy<Integer> getPolicy(IDQN nn) {
|
||||
return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
|
||||
rnd, conf.getMinEpsilon(), this);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
package org.deeplearning4j.rl4j.mdp;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.Schema;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public class CartpoleEnvironment implements Environment<Integer> {
|
||||
private static final int NUM_ACTIONS = 2;
|
||||
private static final int ACTION_LEFT = 0;
|
||||
private static final int ACTION_RIGHT = 1;
|
||||
|
||||
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
|
||||
|
||||
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
||||
|
||||
private static final double gravity = 9.8;
|
||||
private static final double massCart = 1.0;
|
||||
private static final double massPole = 0.1;
|
||||
private static final double totalMass = massPole + massCart;
|
||||
private static final double length = 0.5; // actually half the pole's length
|
||||
private static final double polemassLength = massPole * length;
|
||||
private static final double forceMag = 10.0;
|
||||
private static final double tau = 0.02; // seconds between state updates
|
||||
|
||||
// Angle at which to fail the episode
|
||||
private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
|
||||
private static final double xThreshold = 2.4;
|
||||
|
||||
private final Random rnd;
|
||||
|
||||
@Getter @Setter
|
||||
private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
|
||||
|
||||
@Getter
|
||||
private boolean episodeFinished = false;
|
||||
|
||||
private double x;
|
||||
private double xDot;
|
||||
private double theta;
|
||||
private double thetaDot;
|
||||
private Integer stepsBeyondDone;
|
||||
|
||||
public CartpoleEnvironment() {
|
||||
rnd = new Random();
|
||||
}
|
||||
|
||||
public CartpoleEnvironment(int seed) {
|
||||
rnd = new Random(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Schema<Integer> getSchema() {
|
||||
return schema;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> reset() {
|
||||
|
||||
x = 0.1 * rnd.nextDouble() - 0.05;
|
||||
xDot = 0.1 * rnd.nextDouble() - 0.05;
|
||||
theta = 0.1 * rnd.nextDouble() - 0.05;
|
||||
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
|
||||
stepsBeyondDone = null;
|
||||
episodeFinished = false;
|
||||
|
||||
return new HashMap<String, Object>() {{
|
||||
put("data", new double[]{x, xDot, theta, thetaDot});
|
||||
}};
|
||||
}
|
||||
|
||||
@Override
|
||||
public StepResult step(Integer action) {
|
||||
double force = action == ACTION_RIGHT ? forceMag : -forceMag;
|
||||
double cosTheta = Math.cos(theta);
|
||||
double sinTheta = Math.sin(theta);
|
||||
double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
|
||||
double thetaAcc = (gravity * sinTheta - cosTheta* temp) / (length * (4.0/3.0 - massPole * cosTheta * cosTheta / totalMass));
|
||||
double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;
|
||||
|
||||
switch(kinematicsIntegrator) {
|
||||
case Euler:
|
||||
x += tau * xDot;
|
||||
xDot += tau * xAcc;
|
||||
theta += tau * thetaDot;
|
||||
thetaDot += tau * thetaAcc;
|
||||
break;
|
||||
|
||||
case SemiImplicitEuler:
|
||||
xDot += tau * xAcc;
|
||||
x += tau * xDot;
|
||||
thetaDot += tau * thetaAcc;
|
||||
theta += tau * thetaDot;
|
||||
break;
|
||||
}
|
||||
|
||||
episodeFinished |= x < -xThreshold || x > xThreshold
|
||||
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
|
||||
|
||||
double reward;
|
||||
if(!episodeFinished) {
|
||||
reward = 1.0;
|
||||
}
|
||||
else if(stepsBeyondDone == null) {
|
||||
stepsBeyondDone = 0;
|
||||
reward = 1.0;
|
||||
}
|
||||
else {
|
||||
++stepsBeyondDone;
|
||||
reward = 0;
|
||||
}
|
||||
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("data", new double[]{x, xDot, theta, thetaDot});
|
||||
}};
|
||||
return new StepResult(channelsData, reward, episodeFinished);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
// Do nothing
|
||||
}
|
||||
}
|
|
@ -35,7 +35,7 @@ import java.io.IOException;
|
|||
* the softmax output of the actor critic, but objects constructed
|
||||
* with a {@link Random} argument of null return the max only.
|
||||
*/
|
||||
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
|
||||
public class ACPolicy<O extends Encodable> extends Policy<Integer> {
|
||||
|
||||
final private IActorCritic actorCritic;
|
||||
Random rnd;
|
||||
|
|
|
@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
|
|||
* Boltzmann exploration is a stochastic policy wrt to the
|
||||
* exponential Q-values as evaluated by the dqn model.
|
||||
*/
|
||||
public class BoltzmannQ<O extends Encodable> extends Policy<O, Integer> {
|
||||
public class BoltzmannQ<O extends Encodable> extends Policy<Integer> {
|
||||
|
||||
final private IDQN dqn;
|
||||
final private Random rnd;
|
||||
|
|
|
@ -32,8 +32,10 @@ import java.io.IOException;
|
|||
* DQN policy returns the action with the maximum Q-value as evaluated
|
||||
* by the dqn model
|
||||
*/
|
||||
|
||||
// FIXME: Should we rename this "GreedyPolicy"?
|
||||
@AllArgsConstructor
|
||||
public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {
|
||||
public class DQNPolicy<O> extends Policy<Integer> {
|
||||
|
||||
final private IDQN dqn;
|
||||
|
||||
|
|
|
@ -41,9 +41,9 @@ import org.nd4j.linalg.api.rng.Random;
|
|||
*/
|
||||
@AllArgsConstructor
|
||||
@Slf4j
|
||||
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||
public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
|
||||
|
||||
final private Policy<O, A> policy;
|
||||
final private Policy<A> policy;
|
||||
final private MDP<O, A, AS> mdp;
|
||||
final private int updateStart;
|
||||
final private int epsilonNbStep;
|
||||
|
|
|
@ -7,8 +7,14 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
|
|||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public interface IPolicy<O extends Encodable, A> {
|
||||
<AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
|
||||
A nextAction(INDArray input);
|
||||
A nextAction(Observation observation);
|
||||
public interface IPolicy<ACTION> {
|
||||
@Deprecated
|
||||
<O extends Encodable, AS extends ActionSpace<ACTION>> double play(MDP<O, ACTION, AS> mdp, IHistoryProcessor hp);
|
||||
|
||||
@Deprecated
|
||||
ACTION nextAction(INDArray input);
|
||||
|
||||
ACTION nextAction(Observation observation);
|
||||
|
||||
void reset();
|
||||
}
|
||||
|
|
|
@ -34,22 +34,22 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
|||
*
|
||||
* A Policy responsability is to choose the next action given a state
|
||||
*/
|
||||
public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
||||
public abstract class Policy<A> implements IPolicy<A> {
|
||||
|
||||
public abstract NeuralNet getNeuralNet();
|
||||
|
||||
public abstract A nextAction(Observation obs);
|
||||
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
|
||||
public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
|
||||
return play(mdp, (IHistoryProcessor)null);
|
||||
}
|
||||
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) {
|
||||
public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, HistoryProcessor.Configuration conf) {
|
||||
return play(mdp, new HistoryProcessor(conf));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||
public <O extends Encodable, AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||
resetNetworks();
|
||||
|
||||
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);
|
||||
|
@ -84,8 +84,11 @@ public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
|
|||
protected void resetNetworks() {
|
||||
getNeuralNet().reset();
|
||||
}
|
||||
public void reset() {
|
||||
resetNetworks();
|
||||
}
|
||||
|
||||
protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||
protected <O extends Encodable, AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||
|
||||
double reward = 0;
|
||||
|
||||
|
|
|
@ -0,0 +1,483 @@
|
|||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
|
||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.Schema;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.mockito.*;
|
||||
import org.mockito.junit.*;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
public class AgentTest {
|
||||
|
||||
@Mock Environment environmentMock;
|
||||
@Mock TransformProcess transformProcessMock;
|
||||
@Mock IPolicy policyMock;
|
||||
@Mock AgentListener listenerMock;
|
||||
|
||||
@Rule
|
||||
public MockitoRule mockitoRule = MockitoJUnit.rule();
|
||||
|
||||
@Test
|
||||
public void when_buildingWithNullEnvironment_expect_exception() {
|
||||
try {
|
||||
Agent.builder(null, null, null).build();
|
||||
fail("NullPointerException should have been thrown");
|
||||
} catch (NullPointerException exception) {
|
||||
String expectedMessage = "environment is marked non-null but is null";
|
||||
String actualMessage = exception.getMessage();
|
||||
|
||||
assertTrue(actualMessage.contains(expectedMessage));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildingWithNullTransformProcess_expect_exception() {
|
||||
try {
|
||||
Agent.builder(environmentMock, null, null).build();
|
||||
fail("NullPointerException should have been thrown");
|
||||
} catch (NullPointerException exception) {
|
||||
String expectedMessage = "transformProcess is marked non-null but is null";
|
||||
String actualMessage = exception.getMessage();
|
||||
|
||||
assertTrue(actualMessage.contains(expectedMessage));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildingWithNullPolicy_expect_exception() {
|
||||
try {
|
||||
Agent.builder(environmentMock, transformProcessMock, null).build();
|
||||
fail("NullPointerException should have been thrown");
|
||||
} catch (NullPointerException exception) {
|
||||
String expectedMessage = "policy is marked non-null but is null";
|
||||
String actualMessage = exception.getMessage();
|
||||
|
||||
assertTrue(actualMessage.contains(expectedMessage));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildingWithInvalidMaxSteps_expect_exception() {
|
||||
try {
|
||||
Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(0)
|
||||
.build();
|
||||
fail("IllegalArgumentException should have been thrown");
|
||||
} catch (IllegalArgumentException exception) {
|
||||
String expectedMessage = "maxEpisodeSteps must be greater than 0, got [0]";
|
||||
String actualMessage = exception.getMessage();
|
||||
|
||||
assertTrue(actualMessage.contains(expectedMessage));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildingWithId_expect_idSetInAgent() {
|
||||
// Arrange
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.id("TestAgent")
|
||||
.build();
|
||||
|
||||
// Assert
|
||||
assertEquals("TestAgent", sut.getId());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalled_expect_agentIsReset() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(1);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.build();
|
||||
|
||||
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), anyInt())).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(0, sut.getEpisodeStepNumber());
|
||||
verify(transformProcessMock).transform(envResetResult, 0, false);
|
||||
verify(policyMock, times(1)).reset();
|
||||
assertEquals(0.0, sut.getReward(), 0.00001);
|
||||
verify(environmentMock, times(1)).reset();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
when(environmentMock.isEpisodeFinished()).thenReturn(true);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onBeforeEpisode();
|
||||
verify(spy, times(1)).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
|
||||
|
||||
when(listenerMock.onBeforeEpisode(any(Agent.class))).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onBeforeEpisode();
|
||||
verify(spy, never()).performStep();
|
||||
verify(spy, never()).onAfterStep(any(StepResult.class));
|
||||
verify(spy, never()).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.build();
|
||||
|
||||
final Agent spy = Mockito.spy(sut);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
||||
return null;
|
||||
}).when(spy).performStep();
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onBeforeEpisode();
|
||||
verify(spy, times(5)).performStep();
|
||||
verify(spy, times(1)).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(3)
|
||||
.build();
|
||||
|
||||
final Agent spy = Mockito.spy(sut);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
||||
return null;
|
||||
}).when(spy).performStep();
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onBeforeEpisode();
|
||||
verify(spy, times(3)).performStep();
|
||||
verify(spy, times(1)).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.build();
|
||||
|
||||
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.build();
|
||||
|
||||
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_observationsIsSkipped_expect_performLastAction() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class)))
|
||||
.thenAnswer(invocation -> (int)((Observation)invocation.getArgument(0)).getData().getDouble(0));
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(3)
|
||||
.build();
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||
.thenAnswer(invocation -> {
|
||||
int stepNumber = (int)invocation.getArgument(1);
|
||||
return stepNumber % 2 == 1 ? Observation.SkippedObservation
|
||||
: new Observation(Nd4j.create(new double[] { stepNumber }));
|
||||
});
|
||||
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(policyMock, times(2)).nextAction(any(Observation.class));
|
||||
|
||||
ArgumentCaptor<Agent> agentCaptor = ArgumentCaptor.forClass(Agent.class);
|
||||
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
|
||||
verify(listenerMock, times(3)).onBeforeStep(agentCaptor.capture(), observationCaptor.capture(), actionCaptor.capture());
|
||||
List<Integer> capturedActions = actionCaptor.getAllValues();
|
||||
assertEquals(0, (int)capturedActions.get(0));
|
||||
assertEquals(0, (int)capturedActions.get(1));
|
||||
assertEquals(2, (int)capturedActions.get(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
|
||||
|
||||
when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onBeforeEpisode();
|
||||
verify(spy, times(1)).onBeforeStep();
|
||||
verify(spy, never()).act(any());
|
||||
verify(spy, never()).onAfterStep(any(StepResult.class));
|
||||
verify(spy, never()).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(1)
|
||||
.build();
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
verify(environmentMock, times(1)).step(123);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(1)
|
||||
.build();
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(123.0, sut.getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(234.0, sut.getReward(), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(1)
|
||||
.build();
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy).onAfterStep(stepResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(1)
|
||||
.build();
|
||||
when(listenerMock.onAfterStep(any(Agent.class), any(StepResult.class))).thenReturn(AgentListener.ListenerResponse.STOP);
|
||||
sut.addListener(listenerMock);
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, never()).onAfterEpisode();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
|
||||
.maxEpisodeSteps(1)
|
||||
.build();
|
||||
|
||||
Agent spy = Mockito.spy(sut);
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
||||
// Assert
|
||||
verify(spy, times(1)).onAfterEpisode();
|
||||
}
|
||||
}
|
|
@ -62,7 +62,7 @@ public class AsyncThreadDiscreteTest {
|
|||
IAsyncGlobal<NeuralNet> mockAsyncGlobal;
|
||||
|
||||
@Mock
|
||||
Policy<Encodable, Integer> mockGlobalCurrentPolicy;
|
||||
Policy<Integer> mockGlobalCurrentPolicy;
|
||||
|
||||
@Mock
|
||||
NeuralNet mockGlobalTargetNetwork;
|
||||
|
|
|
@ -30,13 +30,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
|||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.support.MockDQN;
|
||||
import org.deeplearning4j.rl4j.support.MockEncodable;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.support.MockNeuralNet;
|
||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||
import org.deeplearning4j.rl4j.support.MockRandom;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
|
@ -227,7 +222,7 @@ public class PolicyTest {
|
|||
assertEquals(0, dqn.outputParams.size());
|
||||
}
|
||||
|
||||
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
|
||||
public static class MockRefacPolicy extends Policy<Integer> {
|
||||
|
||||
private NeuralNet neuralNet;
|
||||
private final int[] shape;
|
||||
|
@ -257,7 +252,7 @@ public class PolicyTest {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||
protected <O extends Encodable, AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
|
||||
return super.refacInitMdp(mdpWrapper, hp);
|
||||
}
|
||||
|
|
|
@ -5,18 +5,19 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockPolicy implements IPolicy<MockEncodable, Integer> {
|
||||
public class MockPolicy implements IPolicy<Integer> {
|
||||
|
||||
public int playCallCount = 0;
|
||||
public List<INDArray> actionInputs = new ArrayList<INDArray>();
|
||||
|
||||
@Override
|
||||
public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||
public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
|
||||
++playCallCount;
|
||||
return 0;
|
||||
}
|
||||
|
@ -31,4 +32,9 @@ public class MockPolicy implements IPolicy<MockEncodable, Integer> {
|
|||
public Integer nextAction(Observation observation) {
|
||||
return nextAction(observation.getData());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue