RL4J: Add AgentLearner (#470)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
Branch: master
Alexandre Boulanger 2020-05-27 07:41:02 -04:00 committed by GitHub
parent a18417193d
commit 5568b9d72f
40 changed files with 1541 additions and 244 deletions

File: Agent.java

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import lombok.AccessLevel;
@ -14,7 +29,13 @@ import org.nd4j.common.base.Preconditions;
import java.util.Map;

-public class Agent<ACTION> {
+/**
+ * An agent implementation. The Agent will use a {@link IPolicy} to interact with an {@link Environment} and receive
+ * a reward.
+ *
+ * @param <ACTION> The type of action
+ */
+public class Agent<ACTION> implements IAgent<ACTION> {
@Getter
private final String id;
@ -37,19 +58,28 @@ public class Agent<ACTION> {
private ACTION lastAction;
@Getter
-private int episodeStepNumber;
+private int episodeStepCount;
@Getter
private double reward;
protected boolean canContinue;

-private Agent(Builder<ACTION> builder) {
-this.environment = builder.environment;
-this.transformProcess = builder.transformProcess;
-this.policy = builder.policy;
-this.maxEpisodeSteps = builder.maxEpisodeSteps;
-this.id = builder.id;
+/**
+ * @param environment The {@link Environment} to be used
+ * @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
+ * @param policy The {@link IPolicy} to be used
+ * @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
+ * @param id A user-supplied id to identify the instance.
+ */
+public Agent(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id) {
+Preconditions.checkArgument(maxEpisodeSteps == null || maxEpisodeSteps > 0, "maxEpisodeSteps must be null (no maximum) or greater than 0, got", maxEpisodeSteps);
+this.environment = environment;
+this.transformProcess = transformProcess;
+this.policy = policy;
+this.maxEpisodeSteps = maxEpisodeSteps;
+this.id = id;
listeners = buildListenerList();
}
@ -58,10 +88,17 @@ public class Agent<ACTION> {
return new AgentListenerList<ACTION>();
}

+/**
+ * Add a {@link AgentListener} that will be notified when agent events happen
+ * @param listener
+ */
public void addListener(AgentListener listener) {
listeners.add(listener);
}

+/**
+ * This will run a single episode
+ */
public void run() {
runEpisode();
}
@ -80,7 +117,7 @@ public class Agent<ACTION> {
canContinue = listeners.notifyBeforeEpisode(this);
-while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
+while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepCount < maxEpisodeSteps)) {
performStep();
}
@ -100,9 +137,9 @@ public class Agent<ACTION> {
}

protected void resetEnvironment() {
-episodeStepNumber = 0;
+episodeStepCount = 0;
Map<String, Object> channelsData = environment.reset();
-this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
+this.observation = transformProcess.transform(channelsData, episodeStepCount, false);
}

protected void resetPolicy() {
@ -125,7 +162,6 @@ public class Agent<ACTION> {
}

StepResult stepResult = act(action);
-handleStepResult(stepResult);
onAfterStep(stepResult);
@ -134,11 +170,11 @@ public class Agent<ACTION> {
return;
}

-incrementEpisodeStepNumber();
+incrementEpisodeStepCount();
}

-protected void incrementEpisodeStepNumber() {
-++episodeStepNumber;
+protected void incrementEpisodeStepCount() {
+ ++episodeStepCount;
}

protected ACTION decideAction(Observation observation) {
@ -150,12 +186,15 @@ public class Agent<ACTION> {
}

protected StepResult act(ACTION action) {
-return environment.step(action);
-}
-
-protected void handleStepResult(StepResult stepResult) {
-observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
-reward +=computeReward(stepResult);
+Observation observationBeforeAction = observation;
+
+StepResult stepResult = environment.step(action);
+observation = convertChannelDataToObservation(stepResult, episodeStepCount + 1);
+reward += computeReward(stepResult);
+
+onAfterAction(observationBeforeAction, action, stepResult);
+
+return stepResult;
}

protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
@ -166,6 +205,10 @@ public class Agent<ACTION> {
return stepResult.getReward();
}

+protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
+// Do Nothing
+}
+
protected void onAfterStep(StepResult stepResult) {
// Do Nothing
}
@ -174,16 +217,24 @@ public class Agent<ACTION> {
// Do Nothing
}

-public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
+/**
+ *
+ * @param environment
+ * @param transformProcess
+ * @param policy
+ * @param <ACTION>
+ * @return
+ */
+public static <ACTION> Builder<ACTION, Agent> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
return new Builder<>(environment, transformProcess, policy);
}

-public static class Builder<ACTION> {
-private final Environment<ACTION> environment;
-private final TransformProcess transformProcess;
-private final IPolicy<ACTION> policy;
-private Integer maxEpisodeSteps = null; // Default, no max
-private String id;
+public static class Builder<ACTION, AGENT_TYPE extends Agent> {
+protected final Environment<ACTION> environment;
+protected final TransformProcess transformProcess;
+protected final IPolicy<ACTION> policy;
+protected Integer maxEpisodeSteps = null; // Default, no max
+protected String id;

public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
this.environment = environment;
@ -191,20 +242,20 @@ public class Agent<ACTION> {
this.policy = policy;
}

-public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
+public Builder<ACTION, AGENT_TYPE> maxEpisodeSteps(int maxEpisodeSteps) {
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
this.maxEpisodeSteps = maxEpisodeSteps;
return this;
}

-public Builder<ACTION> id(String id) {
+public Builder<ACTION, AGENT_TYPE> id(String id) {
this.id = id;
return this;
}

-public Agent build() {
-return new Agent(this);
+public AGENT_TYPE build() {
+return (AGENT_TYPE)new Agent<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id);
}
}
}
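
For orientation, a minimal usage sketch of the reworked builder API above. The env, transformProcess and policy variables are assumed to be pre-built instances (they are not part of this diff) and the concrete values are arbitrary:

    // Hypothetical wiring; env, transformProcess and policy are assumed pre-built
    // (an Environment<Integer>, a TransformProcess and an IPolicy<Integer>).
    Agent<Integer> agent = Agent.<Integer>builder(env, transformProcess, policy)
            .maxEpisodeSteps(200)      // interrupt an episode after 200 steps
            .id("example-agent")       // hypothetical id
            .build();
    agent.run();                       // plays a single episode
    int steps = agent.getEpisodeStepCount();
    double episodeReward = agent.getReward();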

File: AgentLearner.java (new file)

@ -0,0 +1,115 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
/**
* The AgentLearner is an {@link Agent} that delegates the learning to a {@link ILearningBehavior}.
* @param <ACTION> The type of the action
*/
public class AgentLearner<ACTION> extends Agent<ACTION> implements IAgentLearner<ACTION> {
@Getter
private int totalStepCount = 0;
private final ILearningBehavior<ACTION> learningBehavior;
private double rewardAtLastExperience;
/**
*
* @param environment The {@link Environment} to be used
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
* @param policy The {@link IPolicy} to be used
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
* @param id A user-supplied id to identify the instance.
* @param learningBehavior The {@link ILearningBehavior} that will be used to supervise the learning.
*/
public AgentLearner(Environment<ACTION> environment, TransformProcess transformProcess, IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id, @NonNull ILearningBehavior<ACTION> learningBehavior) {
super(environment, transformProcess, policy, maxEpisodeSteps, id);
this.learningBehavior = learningBehavior;
}
@Override
protected void reset() {
super.reset();
rewardAtLastExperience = 0;
}
@Override
protected void onBeforeEpisode() {
super.onBeforeEpisode();
learningBehavior.handleEpisodeStart();
}
@Override
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
if(!observationBeforeAction.isSkipped()) {
double rewardSinceLastExperience = getReward() - rewardAtLastExperience;
learningBehavior.handleNewExperience(observationBeforeAction, action, rewardSinceLastExperience, stepResult.isTerminal());
rewardAtLastExperience = getReward();
}
}
@Override
protected void onAfterEpisode() {
learningBehavior.handleEpisodeEnd(getObservation());
}
@Override
protected void incrementEpisodeStepCount() {
super.incrementEpisodeStepCount();
++totalStepCount;
}
// FIXME: parent is still visible
public static <ACTION> AgentLearner.Builder<ACTION, AgentLearner<ACTION>> builder(Environment<ACTION> environment,
TransformProcess transformProcess,
IPolicy<ACTION> policy,
ILearningBehavior<ACTION> learningBehavior) {
return new AgentLearner.Builder<ACTION, AgentLearner<ACTION>>(environment, transformProcess, policy, learningBehavior);
}
public static class Builder<ACTION, AGENT_TYPE extends AgentLearner<ACTION>> extends Agent.Builder<ACTION, AGENT_TYPE> {
private final ILearningBehavior<ACTION> learningBehavior;
public Builder(@NonNull Environment<ACTION> environment,
@NonNull TransformProcess transformProcess,
@NonNull IPolicy<ACTION> policy,
@NonNull ILearningBehavior<ACTION> learningBehavior) {
super(environment, transformProcess, policy);
this.learningBehavior = learningBehavior;
}
@Override
public AGENT_TYPE build() {
return (AGENT_TYPE)new AgentLearner<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id, learningBehavior);
}
}
}
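
A hedged sketch of how an AgentLearner is meant to be assembled with the builder above; env, transformProcess, policy and learningBehavior are assumed to be pre-built instances:

    // Hypothetical wiring; env, transformProcess, policy and learningBehavior
    // are assumed to be pre-built instances.
    AgentLearner<Integer> learner =
            AgentLearner.<Integer>builder(env, transformProcess, policy, learningBehavior)
                    .maxEpisodeSteps(1000)
                    .build();
    learner.run();                             // one episode; experience is routed to the ILearningBehavior
    int allEpisodesSteps = learner.getTotalStepCount();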

File: IAgent.java (new file)

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.policy.IPolicy;
/**
* The interface of {@link Agent}
* @param <ACTION>
*/
public interface IAgent<ACTION> {
/**
* Will play a single episode
*/
void run();
/**
* @return A user-supplied id to identify the IAgent instance.
*/
String getId();
/**
* @return The {@link Environment} instance being used by the agent.
*/
Environment<ACTION> getEnvironment();
/**
* @return The {@link IPolicy} instance being used by the agent.
*/
IPolicy<ACTION> getPolicy();
/**
* @return The step count taken in the current episode.
*/
int getEpisodeStepCount();
/**
* @return The cumulative reward received in the current episode.
*/
double getReward();
}

File: IAgentLearner.java (new file)

@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
public interface IAgentLearner<ACTION> extends IAgent<ACTION> {
/**
* @return The total count of steps taken by this AgentLearner, for all episodes.
*/
int getTotalStepCount();
}

File: ILearningBehavior.java (new file)

@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.learning;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* The <code>ILearningBehavior</code> implementations are in charge of the training. Through this interface, they are
* notified as new experience is generated.
*
* @param <ACTION> The type of action
*/
public interface ILearningBehavior<ACTION> {
/**
* This method is called when a new episode has been started.
*/
void handleEpisodeStart();
/**
* This method is called when new experience is generated.
*
* @param observation The observation prior to taking the action
* @param action The action that has been taken
* @param reward The reward received by taking the action
* @param isTerminal True if the episode ended after taking the action
*/
void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal);
/**
* This method is called when the episode ends or the maximum number of episode steps is reached.
*
* @param finalObservation The observation after the last action of the episode has been taken.
*/
void handleEpisodeEnd(Observation finalObservation);
}

File: LearningBehavior.java (new file)

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.learning;
import lombok.Builder;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* A generic {@link ILearningBehavior} that delegates the handling of experience to a {@link ExperienceHandler} and
* the update logic to a {@link IUpdateRule}
*
* @param <ACTION> The type of the action
* @param <EXPERIENCE_TYPE> The type of experience the ExperienceHandler needs
*/
@Builder
public class LearningBehavior<ACTION, EXPERIENCE_TYPE> implements ILearningBehavior<ACTION> {
@Builder.Default
private int experienceUpdateSize = 64;
private final ExperienceHandler<ACTION, EXPERIENCE_TYPE> experienceHandler;
private final IUpdateRule<EXPERIENCE_TYPE> updateRule;
@Override
public void handleEpisodeStart() {
experienceHandler.reset();
}
@Override
public void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal) {
experienceHandler.addExperience(observation, action, reward, isTerminal);
if(experienceHandler.isTrainingBatchReady()) {
updateRule.update(experienceHandler.generateTrainingBatch());
}
}
@Override
public void handleEpisodeEnd(Observation finalObservation) {
experienceHandler.setFinalObservation(finalObservation);
if(experienceHandler.isTrainingBatchReady()) {
updateRule.update(experienceHandler.generateTrainingBatch());
}
}
}
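
A minimal wiring sketch of the Lombok-generated builder above; the experienceHandler and updateRule variables are assumed to exist (for example a ReplayMemoryExperienceHandler and a DQNNeuralNetUpdateRule), and obs is an assumed Observation instance:

    // Hypothetical wiring of the builder generated by @Builder.
    LearningBehavior<Integer, Transition<Integer>> behavior =
            LearningBehavior.<Integer, Transition<Integer>>builder()
                    .experienceHandler(experienceHandler)   // assumed instance
                    .updateRule(updateRule)                 // assumed instance
                    .build();                               // experienceUpdateSize keeps its default of 64
    // During an episode the owning agent would then call:
    behavior.handleEpisodeStart();
    behavior.handleNewExperience(obs, 1, 1.0, false);  // updateRule.update(...) fires once the handler reports a full batch
    behavior.handleEpisodeEnd(obs);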

File: AgentListener.java

@ -1,23 +1,66 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.listener;

import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;

+/**
+ * The base definition of all {@link Agent} event listeners
+ */
public interface AgentListener<ACTION> {
enum ListenerResponse {
/**
-* Tell the learning process to continue calling the listeners and the training.
+* Tell the {@link Agent} to continue calling the listeners and the processing.
*/
CONTINUE,

/**
-* Tell the learning process to stop calling the listeners and terminate the training.
+* Tell the {@link Agent} to interrupt calling the listeners and stop the processing.
*/
STOP,
}

+/**
+ * Called when a new episode is about to start.
+ * @param agent The agent that generated the event
+ *
+ * @return A {@link ListenerResponse}.
+ */
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);

+/**
+ * Called when a step is about to be taken.
+ *
+ * @param agent The agent that generated the event
+ * @param observation The observation before the action is taken
+ * @param action The action that will be performed
+ *
+ * @return A {@link ListenerResponse}.
+ */
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);

+/**
+ * Called after a step has been taken.
+ *
+ * @param agent The agent that generated the event
+ * @param stepResult The {@link StepResult} result of the step.
+ *
+ * @return A {@link ListenerResponse}.
+ */
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
}

File: AgentListenerList.java

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.listener;

import org.deeplearning4j.rl4j.agent.Agent;
@ -7,6 +22,10 @@ import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;

+/**
+ * A class that manages a list of {@link AgentListener AgentListeners} listening to an {@link Agent}.
+ * @param <ACTION>
+ */
public class AgentListenerList<ACTION> {
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
@ -18,6 +37,13 @@ public class AgentListenerList<ACTION> {
listeners.add(listener);
}

+/**
+ * This method will notify all listeners that an episode is about to start. If a listener returns
+ * {@link AgentListener.ListenerResponse STOP}, any following listener is skipped.
+ *
+ * @param agent The agent that generated the event.
+ * @return False if the processing should be stopped
+ */
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
@ -28,6 +54,13 @@ public class AgentListenerList<ACTION> {
return true;
}

+/**
+ *
+ * @param agent The agent that generated the event.
+ * @param observation The observation before the action is taken
+ * @param action The action that will be performed
+ * @return False if the processing should be stopped
+ */
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
@ -38,6 +71,12 @@ public class AgentListenerList<ACTION> {
return true;
}

+/**
+ *
+ * @param agent The agent that generated the event.
+ * @param stepResult The {@link StepResult} result of the step.
+ * @return False if the processing should be stopped
+ */
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {

File: DQNNeuralNetUpdateRule.java (new file)

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.List;
// Temporary class that will be replaced with a more generic class that delegates gradient computation
// and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
@Getter
private final IDQN qNetwork;
@Getter
private IDQN targetQNetwork;
private final int targetUpdateFrequency;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@Getter
private int updateCount = 0;
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
this.qNetwork = qNetwork;
this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(this, gamma, errorClamp)
: new StandardDQN(this, gamma, errorClamp);
}
@Override
public void update(List<Transition<Integer>> trainingBatch) {
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels());
if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork = qNetwork.clone();
}
}
}

File: Gradients.java (new file)

@ -0,0 +1,26 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import lombok.Value;
import org.deeplearning4j.nn.gradient.Gradient;
// Work in progress
@Value
public class Gradients {
private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[]
private int batchSize;
}

File: IUpdateRule.java (new file)

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import java.util.List;
/**
* The role of IUpdateRule implementations is to use an experience batch to improve the accuracy of the policy.
* Used by {@link org.deeplearning4j.rl4j.agent.AgentLearner AgentLearner}
* @param <EXPERIENCE_TYPE> The type of the experience
*/
public interface IUpdateRule<EXPERIENCE_TYPE> {
/**
* Perform the update
* @param trainingBatch A batch of experience
*/
void update(List<EXPERIENCE_TYPE> trainingBatch);
/**
* @return The total number of times the policy has been updated. In a multi-agent learning context, this total is
* for all the agents.
*/
int getUpdateCount();
}

File: ActionSchema.java (deleted)

@ -1,9 +0,0 @@
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
@Value
public class ActionSchema<ACTION> {
private ACTION noOp;
//FIXME ACTION randomAction();
}

File: Environment.java

@ -1,11 +1,54 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;

import java.util.Map;

+/**
+ * An interface for environments used by the {@link org.deeplearning4j.rl4j.agent.Agent Agents}.
+ * @param <ACTION> The type of actions
+ */
public interface Environment<ACTION> {

+/**
+ * @return The {@link Schema} of the environment
+ */
Schema<ACTION> getSchema();

+/**
+ * Reset the environment's state to start a new episode.
+ * @return
+ */
Map<String, Object> reset();

+/**
+ * Perform a single step.
+ *
+ * @param action The action taken
+ * @return A {@link StepResult} describing the result of the step.
+ */
StepResult step(ACTION action);

+/**
+ * @return True if the episode is finished
+ */
boolean isEpisodeFinished();

+/**
+ * Called when the agent is finished using this environment instance.
+ */
void close();
}

File: IActionSchema.java (new file)

@ -0,0 +1,26 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
// Work in progress
public interface IActionSchema<ACTION> {
ACTION getNoOp();
// Review: A schema should be data-only and not have behavior
ACTION getRandomAction();
}

File: IntegerActionSchema.java (new file)

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
// Work in progress
public class IntegerActionSchema implements IActionSchema<Integer> {
private final int numActions;
private final int noOpAction;
private final Random rnd;
public IntegerActionSchema(int numActions, int noOpAction) {
this(numActions, noOpAction, Nd4j.getRandom());
}
public IntegerActionSchema(int numActions, int noOpAction, Random rnd) {
this.numActions = numActions;
this.noOpAction = noOpAction;
this.rnd = rnd;
}
@Override
public Integer getNoOp() {
return noOpAction;
}
@Override
public Integer getRandomAction() {
return rnd.nextInt(numActions);
}
}
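
For illustration, a small sketch of the work-in-progress action schema above; the numbers are arbitrary:

    // A 4-action discrete schema whose no-op action is 0.
    IActionSchema<Integer> schema = new IntegerActionSchema(4, 0);
    Integer noOp = schema.getNoOp();              // 0
    Integer explore = schema.getRandomAction();   // uniformly drawn from {0, 1, 2, 3}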

File: Schema.java

@ -1,8 +1,24 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;

import lombok.Value;

+// Work in progress
@Value
public class Schema<ACTION> {
-private ActionSchema<ACTION> actionSchema;
+private IActionSchema<ACTION> actionSchema;
}

View File

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;

import lombok.Value;

File: ExperienceHandler.java

@ -41,6 +41,11 @@ public interface ExperienceHandler<A, E> {
*/
int getTrainingBatchSize();

+/**
+ * @return True if a batch is ready for training.
+ */
+boolean isTrainingBatchReady();

/**
* The elements are returned in the historical order (i.e. in the order they happened)
* @return The list of experience elements

File: ReplayMemoryExperienceHandler.java

@ -36,6 +36,7 @@ import java.util.List;
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
private static final int DEFAULT_BATCH_SIZE = 32;
+private final int batchSize;

private IExpReplay<A> expReplay;
@ -43,6 +44,7 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
this.expReplay = expReplay;
+this.batchSize = expReplay.getDesignatedBatchSize();
}

public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
@ -64,6 +66,11 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
return expReplay.getBatchSize();
}

+@Override
+public boolean isTrainingBatchReady() {
+return expReplay.getBatchSize() >= batchSize;
+}

/**
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
*/
*/ */

File: StateActionExperienceHandler.java

@ -30,10 +30,18 @@ import java.util.List;
*/
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {

+private final int batchSize;
+private boolean isFinalObservationSet;
+
+public StateActionExperienceHandler(int batchSize) {
+this.batchSize = batchSize;
+}

private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();

public void setFinalObservation(Observation observation) {
-// Do nothing
+isFinalObservationSet = true;
}

public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
@ -45,6 +53,12 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
return stateActionPairs.size();
}

+@Override
+public boolean isTrainingBatchReady() {
+return stateActionPairs.size() >= batchSize
+|| (isFinalObservationSet && stateActionPairs.size() > 0);
+}

/**
* The elements are returned in the historical order (i.e. in the order they happened)
* Note: the experience store is cleared after calling this method.
@ -62,6 +76,7 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
@Override
public void reset() {
stateActionPairs = new ArrayList<>();
+isFinalObservationSet = false;
}
}
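
To make the new trigger explicit: a training batch is reported ready either when the configured batch size is reached or, once the final observation has been set, as soon as the store is non-empty. A hypothetical illustration, where obs is an assumed Observation instance:

    StateActionExperienceHandler<Integer> handler = new StateActionExperienceHandler<>(5);
    handler.addExperience(obs, 1, 1.0, false);
    boolean ready = handler.isTrainingBatchReady();   // false: 1 element < batch size of 5
    handler.setFinalObservation(obs);                 // episode ended before a full batch
    ready = handler.isTrainingBatchReady();           // true: the partial batch is flushed at episode end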

File: INDArrayHelper.java

@ -24,17 +24,38 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
public class INDArrayHelper {
/**
-* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
-*
-* We must have either shape 2 (NK) or shape 4 (NCHW)
+* Force the input source to have the correct shape:
+* <p><ul>
+* <li>DL4J requires it to be at least 2D</li>
+* <li>RL4J has a convention to have the batch size on dimension 0 to all INDArrays</li>
+* </ul></p>
+* @param source The {@link INDArray} to be corrected.
+* @return The corrected INDArray
*/
public static INDArray forceCorrectShape(INDArray source) {
-return source.shape()[0] == 1 && source.shape().length > 1
+return source.shape()[0] == 1 && source.rank() > 1
? source
: Nd4j.expandDims(source, 0);
}

+/**
+ * This will create a INDArray with <i>batchSize</i> as dimension 0 and <i>shape</i> as other dimensions.
+ * For example, if <i>batchSize</i> is 10 and shape is { 1, 3, 4 }, the resulting INDArray shape will be { 10, 3, 4}
+ * @param batchSize The size of the batch to create
+ * @param shape The shape of individual elements.
+ * Note: all shapes in RL4J should have a batch size as dimension 0; in this case the batch size should be 1.
+ * @return A INDArray
+ */
+public static INDArray createBatchForShape(long batchSize, long... shape) {
+long[] batchShape;
+batchShape = new long[shape.length];
+System.arraycopy(shape, 0, batchShape, 0, shape.length);
+batchShape[0] = batchSize;
+return Nd4j.create(batchShape);
+}
}
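
A short example of the helper added above. The element shape follows the RL4J convention of a leading batch dimension of 1, which createBatchForShape replaces with the requested batch size:

    // An individual observation of shape {1, 3, 4} (batch size 1 by convention);
    // a batch of 10 such observations is allocated with shape {10, 3, 4}.
    INDArray batch = INDArrayHelper.createBatchForShape(10, 1, 3, 4);
    long[] shape = batch.shape();   // {10, 3, 4}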

File: AsyncThreadDiscrete.java

@ -25,6 +25,7 @@ import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
@ -49,7 +50,7 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
// TODO: Make it configurable with a builder
@Setter(AccessLevel.PROTECTED) @Getter
-private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
+private ExperienceHandler experienceHandler;

public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
@ -60,6 +61,17 @@
synchronized (asyncGlobal) {
current = (NN) asyncGlobal.getTarget().clone();
}
+
+experienceHandler = new StateActionExperienceHandler(getNStep());
}

+private int getNStep() {
+IAsyncLearningConfiguration configuration = getConfiguration();
+if(configuration == null) {
+return Integer.MAX_VALUE;
+}
+
+return configuration.getNStep();
+}

// TODO: Add an actor-learner class and be able to inject the update algorithm

File: AsyncNStepQLearningThreadDiscrete.java

@ -71,7 +71,6 @@ public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> ex
@Override
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
-int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
+return new QLearningUpdateAlgorithm(getMdp().getActionSpace().getSize(), configuration.getGamma());
}
}

File: QLearningUpdateAlgorithm.java

@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -27,15 +27,12 @@ import java.util.List;
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {

-private final int[] shape;
private final int actionSpaceSize;
private final double gamma;

-public QLearningUpdateAlgorithm(int[] shape,
-int actionSpaceSize,
+public QLearningUpdateAlgorithm(int actionSpaceSize,
double gamma) {
-this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.gamma = gamma;
}
@ -44,33 +41,34 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
int size = experience.size();

-int[] nshape = Learning.makeShape(size, shape);
-INDArray input = Nd4j.create(nshape);
-INDArray targets = Nd4j.create(size, actionSpaceSize);

StateActionPair<Integer> stateActionPair = experience.get(size - 1);
+INDArray data = stateActionPair.getObservation().getData();
+INDArray features = INDArrayHelper.createBatchForShape(size, data.shape());
+INDArray targets = Nd4j.create(size, actionSpaceSize);

double r;
if (stateActionPair.isTerminal()) {
r = 0;
} else {
INDArray[] output = null;
-output = current.outputAll(stateActionPair.getObservation().getData());
+output = current.outputAll(data);
r = Nd4j.max(output[0]).getDouble(0);
}

for (int i = size - 1; i >= 0; i--) {
stateActionPair = experience.get(i);
+data = stateActionPair.getObservation().getData();

-input.putRow(i, stateActionPair.getObservation().getData());
+features.putRow(i, data);
r = stateActionPair.getReward() + gamma * r;
-INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
+INDArray[] output = current.outputAll(data);
INDArray row = output[0];
row = row.putScalar(stateActionPair.getAction(), r);
targets.putRow(i, row);
}

-return current.gradient(input, targets);
+return current.gradient(features, targets);
}
}

File: ExpReplay.java

@ -80,6 +80,11 @@ public class ExpReplay<A> implements IExpReplay<A> {
//log.info("size: "+storage.size());
}

+@Override
+public int getDesignatedBatchSize() {
+return batchSize;
+}

public int getBatchSize() {
int storageSize = storage.size();
return Math.min(storageSize, batchSize);

File: IExpReplay.java

@ -47,4 +47,9 @@ public interface IExpReplay<A> {
* @param transition a new transition to store
*/
void store(Transition<A> transition);

+/**
+ * @return The desired size of batches
+ */
+int getDesignatedBatchSize();
}

File: QLearning.java

@ -51,25 +51,16 @@ import java.util.List;
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN>
-implements TargetQNetworkSource, IEpochTrainer {
+implements IEpochTrainer {

protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();

-protected abstract EpsGreedy<O, A, AS> getEgPolicy();
+protected abstract EpsGreedy<A> getEgPolicy();

public abstract MDP<O, A, AS> getMdp();

public abstract IDQN getQNetwork();

-public abstract IDQN getTargetQNetwork();
-protected abstract void setTargetQNetwork(IDQN dqn);
-
-protected void updateTargetNetwork() {
-log.info("Update target network");
-setTargetQNetwork(getQNetwork().clone());
-}

public IDQN getNeuralNet() {
return getQNetwork();
}
@ -101,11 +92,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
int numQ = 0;
List<Double> scores = new ArrayList<>();
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {

-if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
-updateTargetNetwork();
-}

QLStepReturn<Observation> stepR = trainStep(obs);

if (!stepR.getMaxQ().isNaN()) {
@ -146,7 +132,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected void resetNetworks() {
getQNetwork().reset();
-getTargetQNetwork().reset();
}

private InitMdp<Observation> refacInitMdp() {

File: QLearningDiscrete.java

@ -21,6 +21,10 @@ import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
import org.deeplearning4j.rl4j.agent.update.DQNNeuralNetUpdateRule;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler; import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
@ -28,9 +32,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
@ -41,12 +42,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@ -63,22 +60,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter @Getter
private DQNPolicy<O> policy; private DQNPolicy<O> policy;
@Getter @Getter
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy; private EpsGreedy<Integer> egPolicy;
@Getter @Getter
final private IDQN qNetwork; final private IDQN qNetwork;
@Getter
@Setter(AccessLevel.PROTECTED)
private IDQN targetQNetwork;
private int lastAction; private int lastAction;
private double accuReward = 0; private double accuReward = 0;
ITDTargetAlgorithm tdTargetAlgorithm; private final ILearningBehavior<Integer> learningBehavior;
// TODO: User a builder and remove the setter
@Getter(AccessLevel.PROTECTED) @Setter
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() { protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
return mdp; return mdp;
@ -88,21 +78,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
} }
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) {
this(mdp, dqn, conf, epsilonNbStep, buildLearningBehavior(dqn, conf, random), random);
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) { int epsilonNbStep, ILearningBehavior<Integer> learningBehavior, Random random) {
this.configuration = conf; this.configuration = conf;
this.mdp = new LegacyMDPWrapper<>(mdp, null); this.mdp = new LegacyMDPWrapper<>(mdp, null);
qNetwork = dqn; qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork()); policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(), egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
this); this);
tdTargetAlgorithm = conf.isDoubleDQN() this.learningBehavior = learningBehavior;
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp()) }
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
private static ILearningBehavior<Integer> buildLearningBehavior(IDQN qNetwork, QLearningConfiguration conf, Random random) {
IUpdateRule<Transition<Integer>> updateRule = new DQNNeuralNetUpdateRule(qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp());
ExperienceHandler<Integer, Transition<Integer>> experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
return LearningBehavior.<Integer, Transition<Integer>>builder()
.experienceHandler(experienceHandler)
.updateRule(updateRule)
.experienceUpdateSize(conf.getBatchSize())
.build();
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
} }
public MDP<O, Integer, DiscreteSpace> getMdp() { public MDP<O, Integer, DiscreteSpace> getMdp() {
@ -119,7 +119,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public void preEpoch() { public void preEpoch() {
lastAction = mdp.getActionSpace().noOp(); lastAction = mdp.getActionSpace().noOp();
accuReward = 0; accuReward = 0;
experienceHandler.reset(); learningBehavior.handleEpisodeStart();
} }
@Override @Override
@ -136,12 +136,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
*/ */
protected QLStepReturn<Observation> trainStep(Observation obs) { protected QLStepReturn<Observation> trainStep(Observation obs) {
boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = this.getConfiguration().getUpdateStart()
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
Double maxQ = Double.NaN; //ignore if NaN for stats Double maxQ = Double.NaN; //ignore if NaN for stats
//if step of training, just repeat lastAction //if step of training, just repeat lastAction
@ -160,29 +154,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (!obs.isSkipped()) { if (!obs.isSkipped()) {
// Add experience // Add experience
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone()); learningBehavior.handleNewExperience(obs, lastAction, accuReward, stepReply.isDone());
accuReward = 0; accuReward = 0;
// Update NN
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (this.getStepCount() > updateStart) {
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
} }
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply); return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
} }
protected DataSet setTarget(List<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
return tdTargetAlgorithm.computeTDTargets(transitions);
}
@Override @Override
protected void finishEpoch(Observation observation) { protected void finishEpoch(Observation observation) {
experienceHandler.setFinalObservation(observation); learningBehavior.handleEpisodeEnd(observation);
} }
} }
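For context: the change above removes the inline target-network and experience-replay handling from QLearningDiscrete and injects it as an ILearningBehavior instead. A minimal sketch of how that behavior is assembled and handed to the new constructor, mirroring buildLearningBehavior() in this diff (qNetwork, conf, mdp, epsilonNbStep and random stand for whatever instances the caller already holds):

// Sketch only: the same wiring as buildLearningBehavior() above.
IUpdateRule<Transition<Integer>> updateRule = new DQNNeuralNetUpdateRule(
        qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp());
ExperienceHandler<Integer, Transition<Integer>> experienceHandler =
        new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
ILearningBehavior<Integer> learningBehavior = LearningBehavior.<Integer, Transition<Integer>>builder()
        .experienceHandler(experienceHandler)
        .updateRule(updateRule)
        .experienceUpdateSize(conf.getBatchSize())
        .build();
// The behavior is then passed through the new constructor parameter:
// (mdp, qNetwork, conf, epsilonNbStep, learningBehavior, random), via a concrete subclass since QLearningDiscrete is abstract.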
View File
@ -2,21 +2,19 @@ package org.deeplearning4j.rl4j.mdp;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.deeplearning4j.rl4j.environment.ActionSchema; import org.deeplearning4j.rl4j.environment.*;
import org.deeplearning4j.rl4j.environment.Environment; import org.nd4j.linalg.api.rng.Random;
import org.deeplearning4j.rl4j.environment.Schema; import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.rl4j.environment.StepResult;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Random;
public class CartpoleEnvironment implements Environment<Integer> { public class CartpoleEnvironment implements Environment<Integer> {
private static final int NUM_ACTIONS = 2; private static final int NUM_ACTIONS = 2;
private static final int ACTION_LEFT = 0; private static final int ACTION_LEFT = 0;
private static final int ACTION_RIGHT = 1; private static final int ACTION_RIGHT = 1;
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT)); private final Schema<Integer> schema;
public enum KinematicsIntegrators { Euler, SemiImplicitEuler }; public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
@ -48,11 +46,12 @@ public class CartpoleEnvironment implements Environment<Integer> {
private Integer stepsBeyondDone; private Integer stepsBeyondDone;
public CartpoleEnvironment() { public CartpoleEnvironment() {
rnd = new Random(); this(Nd4j.getRandom());
} }
public CartpoleEnvironment(int seed) { public CartpoleEnvironment(Random rnd) {
rnd = new Random(seed); this.rnd = rnd;
this.schema = new Schema<Integer>(new IntegerActionSchema(NUM_ACTIONS, ACTION_LEFT, rnd));
} }
@Override @Override
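As a usage note, the seed-taking constructor is replaced by one that accepts an Nd4j Random, so a reproducible environment is now obtained by seeding the Random yourself. A hypothetical example (the seed value is illustrative):

// Sketch: seeding CartpoleEnvironment through the new Random-based constructor.
Random rnd = Nd4j.getRandomFactory().getNewRandomInstance(123);
CartpoleEnvironment env = new CartpoleEnvironment(rnd);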
View File
@ -17,16 +17,19 @@
package org.deeplearning4j.rl4j.policy; package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor; import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.environment.IActionSchema;
import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
@ -38,18 +41,60 @@ import org.nd4j.linalg.api.rng.Random;
* epsilon is annealed to minEpsilon over epsilonNbStep steps * epsilon is annealed to minEpsilon over epsilonNbStep steps
* *
*/ */
@AllArgsConstructor
@Slf4j @Slf4j
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> { public class EpsGreedy<A> extends Policy<A> {
final private Policy<A> policy; final private INeuralNetPolicy<A> policy;
final private MDP<OBSERVATION, A, AS> mdp;
final private int updateStart; final private int updateStart;
final private int epsilonNbStep; final private int epsilonNbStep;
final private Random rnd; final private Random rnd;
final private double minEpsilon; final private double minEpsilon;
private final IActionSchema<A> actionSchema;
final private MDP<Encodable, A, ActionSpace<A>> mdp;
final private IEpochTrainer learning; final private IEpochTrainer learning;
// Using agent's (learning's) step count is incorrect; frame skipping makes epsilon's value decrease too quickly
private int annealingStep = 0;
@Deprecated
public <OBSERVATION extends Encodable, AS extends ActionSpace<A>> EpsGreedy(Policy<A> policy,
MDP<Encodable, A, ActionSpace<A>> mdp,
int updateStart,
int epsilonNbStep,
Random rnd,
double minEpsilon,
IEpochTrainer learning) {
this.policy = policy;
this.mdp = mdp;
this.updateStart = updateStart;
this.epsilonNbStep = epsilonNbStep;
this.rnd = rnd;
this.minEpsilon = minEpsilon;
this.learning = learning;
this.actionSchema = null;
}
public EpsGreedy(@NonNull Policy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep) {
this(policy, actionSchema, minEpsilon, updateStart, epsilonNbStep, null);
}
@Builder
public EpsGreedy(@NonNull INeuralNetPolicy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep, Random rnd) {
this.policy = policy;
this.rnd = rnd == null ? Nd4j.getRandom() : rnd;
this.minEpsilon = minEpsilon;
this.updateStart = updateStart;
this.epsilonNbStep = epsilonNbStep;
this.actionSchema = actionSchema;
this.mdp = null;
this.learning = null;
}
public NeuralNet getNeuralNet() { public NeuralNet getNeuralNet() {
return policy.getNeuralNet(); return policy.getNeuralNet();
} }
@ -57,6 +102,11 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
public A nextAction(INDArray input) { public A nextAction(INDArray input) {
double ep = getEpsilon(); double ep = getEpsilon();
if(actionSchema != null) {
// Only legacy classes should pass here.
throw new RuntimeException("nextAction(Observation observation) should be called when using an AgentLearner");
}
if (learning.getStepCount() % 500 == 1) if (learning.getStepCount() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCount()); log.info("EP: " + ep + " " + learning.getStepCount());
if (rnd.nextDouble() > ep) if (rnd.nextDouble() > ep)
@ -66,10 +116,31 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
} }
public A nextAction(Observation observation) { public A nextAction(Observation observation) {
return this.nextAction(observation.getData()); if(actionSchema == null) {
return this.nextAction(observation.getData());
}
A result;
double ep = getEpsilon();
if (annealingStep % 500 == 1) {
log.info("EP: " + ep + " " + annealingStep);
}
if (rnd.nextDouble() > ep) {
result = policy.nextAction(observation);
}
else {
result = actionSchema.getRandomAction();
}
++annealingStep;
return result;
} }
public double getEpsilon() { public double getEpsilon() {
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep)); int step = actionSchema != null ? annealingStep : learning.getStepCount();
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep));
} }
} }
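Two details of the new EpsGreedy are worth spelling out. It keeps its own annealingStep counter on the schema-based path because, as the comment in the class notes, the trainer's step count also advances on skipped frames and would decay epsilon too quickly. The decay itself is linear: epsilon = min(1.0, max(minEpsilon, 1.0 - (step - updateStart) / epsilonNbStep)); with updateStart = 0, epsilonNbStep = 1000 and minEpsilon = 0.1, epsilon is 1.0 at step 0, 0.5 at step 500, and stays clamped at 0.1 from step 900 onward. A hypothetical construction through the Lombok-generated builder (dqnPolicy and the schema arguments are illustrative; the IntegerActionSchema argument order is assumed to be number of actions then no-op action, as in the CartpoleEnvironment change above):

// Sketch: building the schema-based EpsGreedy; rnd defaults to Nd4j.getRandom() when omitted.
EpsGreedy<Integer> egPolicy = EpsGreedy.<Integer>builder()
        .policy(dqnPolicy)                           // any INeuralNetPolicy<Integer>
        .actionSchema(new IntegerActionSchema(2, 0)) // 2 actions, action 0 used as the no-op (illustrative)
        .minEpsilon(0.1)
        .updateStart(0)
        .epsilonNbStep(1000)
        .build();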
View File
@ -0,0 +1,7 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.network.NeuralNet;
public interface INeuralNetPolicy<ACTION> extends IPolicy<ACTION> {
NeuralNet getNeuralNet();
}
View File
@ -34,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
* *
* A Policy's responsibility is to choose the next action given a state * A Policy's responsibility is to choose the next action given a state
*/ */
public abstract class Policy<A> implements IPolicy<A> { public abstract class Policy<A> implements INeuralNetPolicy<A> {
public abstract NeuralNet getNeuralNet(); public abstract NeuralNet getNeuralNet();
View File
@ -0,0 +1,211 @@
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.IntegerActionSchema;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
import static org.junit.Assert.*;
@RunWith(MockitoJUnitRunner.class)
public class AgentLearnerTest {
@Mock
Environment<Integer> environmentMock;
@Mock
TransformProcess transformProcessMock;
@Mock
IPolicy<Integer> policyMock;
@Mock
LearningBehavior<Integer, Object> learningBehaviorMock;
@Test
public void when_episodeIsStarted_expect_learningBehaviorHandleEpisodeStartCalled() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(3)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
// Act
sut.run();
// Assert
verify(learningBehaviorMock, times(1)).handleEpisodeStart();
}
@Test
public void when_runIsCalled_expect_experienceHandledWithLearningBehavior() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
ArgumentCaptor<Boolean> isTerminalCaptor = ArgumentCaptor.forClass(Boolean.class);
verify(learningBehaviorMock, times(2)).handleNewExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminalCaptor.capture());
List<Observation> observations = observationCaptor.getAllValues();
List<Integer> actions = actionCaptor.getAllValues();
List<Double> rewards = rewardCaptor.getAllValues();
List<Boolean> isTerminalList = isTerminalCaptor.getAllValues();
assertEquals(0.0, observations.get(0).getData().getDouble(0), 0.00001);
assertEquals(0, (int)actions.get(0));
assertEquals(0.0 + 1.0, rewards.get(0), 0.00001);
assertFalse(isTerminalList.get(0));
assertEquals(2.2, observations.get(1).getData().getDouble(0), 0.00001);
assertEquals(2, (int)actions.get(1));
assertEquals(2.0 + 3.0, rewards.get(1), 0.00001);
assertFalse(isTerminalList.get(1));
ArgumentCaptor<Observation> finalObservationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(learningBehaviorMock, times(1)).handleEpisodeEnd(finalObservationCaptor.capture());
assertEquals(4.4, finalObservationCaptor.getValue().getData().getDouble(0), 0.00001);
}
@Test
public void when_runIsCalledMultipleTimes_expect_totalStepCountCorrect() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
reward[0] = 0.0;
sut.run();
// Assert
assertEquals(8, sut.getTotalStepCount());
}
@Test
public void when_runIsCalledMultipleTimes_expect_rewardSentToLearningBehaviorToBeCorrect() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
reward[0] = 0.0;
sut.run();
// Assert
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
verify(learningBehaviorMock, times(4)).handleNewExperience(any(Observation.class), any(Integer.class), rewardCaptor.capture(), any(Boolean.class));
List<Double> rewards = rewardCaptor.getAllValues();
// rewardAtLastExperience at the end of 1st call to .run() should not leak into 2nd call.
assertEquals(0.0 + 1.0, rewards.get(2), 0.00001);
assertEquals(2.0 + 3.0, rewards.get(3), 0.00001);
}
}
View File
@ -1,10 +1,7 @@
package org.deeplearning4j.rl4j.agent; package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.agent.listener.AgentListener; import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.environment.ActionSchema; import org.deeplearning4j.rl4j.environment.*;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess; import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
@ -12,6 +9,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.*; import static org.junit.Assert.*;
import org.junit.runner.RunWith;
import org.mockito.*; import org.mockito.*;
import org.mockito.junit.*; import org.mockito.junit.*;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -23,8 +21,8 @@ import java.util.Map;
import static org.mockito.ArgumentMatchers.*; import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*; import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class AgentTest { public class AgentTest {
@Mock Environment environmentMock; @Mock Environment environmentMock;
@Mock TransformProcess transformProcessMock; @Mock TransformProcess transformProcessMock;
@Mock IPolicy policyMock; @Mock IPolicy policyMock;
@ -102,7 +100,7 @@ public class AgentTest {
public void when_runIsCalled_expect_agentIsReset() { public void when_runIsCalled_expect_agentIsReset() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -119,7 +117,7 @@ public class AgentTest {
sut.run(); sut.run();
// Assert // Assert
assertEquals(0, sut.getEpisodeStepNumber()); assertEquals(0, sut.getEpisodeStepCount());
verify(transformProcessMock).transform(envResetResult, 0, false); verify(transformProcessMock).transform(envResetResult, 0, false);
verify(policyMock, times(1)).reset(); verify(policyMock, times(1)).reset();
assertEquals(0.0, sut.getReward(), 0.00001); assertEquals(0.0, sut.getReward(), 0.00001);
@ -130,7 +128,7 @@ public class AgentTest {
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() { public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -152,7 +150,7 @@ public class AgentTest {
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() { public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -179,7 +177,7 @@ public class AgentTest {
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() { public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -191,10 +189,10 @@ public class AgentTest {
final Agent spy = Mockito.spy(sut); final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> { doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber(); ((Agent)invocation.getMock()).incrementEpisodeStepCount();
return null; return null;
}).when(spy).performStep(); }).when(spy).performStep();
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 ); when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepCount() >= 5 );
// Act // Act
spy.run(); spy.run();
@ -209,7 +207,7 @@ public class AgentTest {
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() { public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -222,7 +220,7 @@ public class AgentTest {
final Agent spy = Mockito.spy(sut); final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> { doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber(); ((Agent)invocation.getMock()).incrementEpisodeStepCount();
return null; return null;
}).when(spy).performStep(); }).when(spy).performStep();
@ -239,7 +237,7 @@ public class AgentTest {
public void when_initialObservationsAreSkipped_expect_performNoOpAction() { public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -264,7 +262,7 @@ public class AgentTest {
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() { public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -289,7 +287,7 @@ public class AgentTest {
public void when_observationsIsSkipped_expect_performLastAction() { public void when_observationsIsSkipped_expect_performLastAction() {
// Arrange // Arrange
Map<String, Object> envResetResult = new HashMap<>(); Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult); when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false)); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -331,7 +329,7 @@ public class AgentTest {
@Test @Test
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() { public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
@ -358,7 +356,7 @@ public class AgentTest {
@Test @Test
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() { public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false)); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
@ -381,7 +379,7 @@ public class AgentTest {
@Test @Test
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() { public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false)); when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
@ -405,7 +403,7 @@ public class AgentTest {
@Test @Test
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() { public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
@ -430,7 +428,7 @@ public class AgentTest {
@Test @Test
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() { public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
@ -458,7 +456,7 @@ public class AgentTest {
@Test @Test
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() { public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
// Arrange // Arrange
Schema schema = new Schema(new ActionSchema<>(-1)); Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>()); when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema); when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false); StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
View File
@ -0,0 +1,133 @@
package org.deeplearning4j.rl4j.agent.learning;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class LearningBehaviorTest {
@Mock
ExperienceHandler<Integer, Object> experienceHandlerMock;
@Mock
IUpdateRule<Object> updateRuleMock;
LearningBehavior<Integer, Object> sut;
@Before
public void setup() {
sut = LearningBehavior.<Integer, Object>builder()
.experienceHandler(experienceHandlerMock)
.updateRule(updateRuleMock)
.build();
}
@Test
public void when_callingHandleEpisodeStart_expect_experienceHandlerResetCalled() {
// Arrange
LearningBehavior<Integer, Object> sut = LearningBehavior.<Integer, Object>builder()
.experienceHandler(experienceHandlerMock)
.updateRule(updateRuleMock)
.build();
// Act
sut.handleEpisodeStart();
// Assert
verify(experienceHandlerMock, times(1)).reset();
}
@Test
public void when_callingHandleNewExperience_expect_experienceHandlerAddExperienceCalled() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
// Act
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
ArgumentCaptor<Boolean> isTerminatedCaptor = ArgumentCaptor.forClass(Boolean.class);
verify(experienceHandlerMock, times(1)).addExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminatedCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
assertEquals(1, (int)actionCaptor.getValue());
assertEquals(2.0, (double)rewardCaptor.getValue(), 0.00001);
assertFalse(isTerminatedCaptor.getValue());
verify(updateRuleMock, never()).update(any(List.class));
}
@Test
public void when_callingHandleNewExperienceAndTrainingBatchIsReady_expect_updateRuleUpdateWithTrainingBatch() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
List<Object> trainingBatch = new ArrayList<Object>();
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
// Act
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
// Assert
verify(updateRuleMock, times(1)).update(trainingBatch);
}
@Test
public void when_callingHandleEpisodeEnd_expect_experienceHandlerSetFinalObservationCalled() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
// Act
sut.handleEpisodeEnd(new Observation(observationData));
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
verify(updateRuleMock, never()).update(any(List.class));
}
@Test
public void when_callingHandleEpisodeEndAndTrainingBatchIsNotEmpty_expect_updateRuleUpdateWithTrainingBatch() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
List<Object> trainingBatch = new ArrayList<Object>();
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
// Act
sut.handleEpisodeEnd(new Observation(observationData));
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
verify(updateRuleMock, times(1)).update(trainingBatch);
}
}
View File
@ -4,34 +4,44 @@ import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class ReplayMemoryExperienceHandlerTest { public class ReplayMemoryExperienceHandlerTest {
@Mock
IExpReplay<Integer> expReplayMock;
@Test @Test
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() { public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
// Arrange // Arrange
TestExpReplay expReplayMock = new TestExpReplay(); when(expReplayMock.getDesignatedBatchSize()).thenReturn(10);
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act // Act
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
int numStoredTransitions = expReplayMock.addedTransitions.size(); boolean isStoreCalledAfterFirstAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> "store".equals(x.getMethod().getName()));
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
boolean isStoreCalledAfterSecondAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> "store".equals(x.getMethod().getName()));
// Assert // Assert
assertEquals(0, numStoredTransitions); assertFalse(isStoreCalledAfterFirstAdd);
assertEquals(1, expReplayMock.addedTransitions.size()); assertTrue(isStoreCalledAfterSecondAdd);
} }
@Test @Test
public void when_addingExperience_expect_transitionsAreCorrect() { public void when_addingExperience_expect_transitionsAreCorrect() {
// Arrange // Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act // Act
@ -40,24 +50,25 @@ public class ReplayMemoryExperienceHandlerTest {
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Assert // Assert
assertEquals(2, expReplayMock.addedTransitions.size()); ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
verify(expReplayMock, times(2)).store(argument.capture());
List<Transition<Integer>> transitions = argument.getAllValues();
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001); assertEquals(1.0, transitions.get(0).getObservation().getData().getDouble(0), 0.00001);
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); assertEquals(1, (int)transitions.get(0).getAction());
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001); assertEquals(1.0, transitions.get(0).getReward(), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001); assertEquals(2.0, transitions.get(0).getNextObservation().getDouble(0), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001); assertEquals(2.0, transitions.get(1).getObservation().getData().getDouble(0), 0.00001);
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction()); assertEquals(2, (int)transitions.get(1).getAction());
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001); assertEquals(2.0, transitions.get(1).getReward(), 0.00001);
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001); assertEquals(3.0, transitions.get(1).getNextObservation().getDouble(0), 0.00001);
} }
@Test @Test
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() { public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
// Arrange // Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act // Act
@ -66,42 +77,57 @@ public class ReplayMemoryExperienceHandlerTest {
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
// Assert // Assert
assertEquals(1, expReplayMock.addedTransitions.size()); ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction()); verify(expReplayMock, times(1)).store(argument.capture());
Transition<Integer> transition = argument.getValue();
assertEquals(1, (int)transition.getAction());
} }
@Test @Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange // Arrange
TestExpReplay expReplayMock = new TestExpReplay(); ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false); sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 }))); sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Act // Act
int size = sut.getTrainingBatchSize(); int size = sut.getTrainingBatchSize();
// Assert // Assert
assertEquals(2, size); assertEquals(2, size);
} }
private static class TestExpReplay implements IExpReplay<Integer> { @Test
public void when_experienceSizeIsSmallerThanBatchSize_expect_TrainingBatchIsNotReady() {
// Arrange
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
public final List<Transition<Integer>> addedTransitions = new ArrayList<>(); // Act
@Override // Assert
public ArrayList<Transition<Integer>> getBatch() { assertFalse(sut.isTrainingBatchReady());
return null;
}
@Override
public void store(Transition<Integer> transition) {
addedTransitions.add(transition);
}
@Override
public int getBatchSize() {
return addedTransitions.size();
}
} }
@Test
public void when_experienceSizeIsGreaterOrEqualToBatchSize_expect_TrainingBatchIsReady() {
// Arrange
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 4.0 })), 4, 4.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 5.0 })), 5, 5.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 6.0 })));
// Act
// Assert
assertTrue(sut.isTrainingBatchReady());
}
} }
View File
@ -13,7 +13,7 @@ public class StateActionExperienceHandlerTest {
@Test @Test
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() { public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
// Arrange // Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(); StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset(); sut.reset();
Observation observation = new Observation(Nd4j.zeros(1)); Observation observation = new Observation(Nd4j.zeros(1));
sut.addExperience(observation, 123, 234.0, true); sut.addExperience(observation, 123, 234.0, true);
@ -32,7 +32,7 @@ public class StateActionExperienceHandlerTest {
@Test @Test
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() { public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
// Arrange // Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(); StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset(); sut.reset();
sut.addExperience(null, 1, 1.0, false); sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false); sut.addExperience(null, 2, 2.0, false);
@ -51,7 +51,7 @@ public class StateActionExperienceHandlerTest {
@Test @Test
public void when_gettingExperience_expect_experienceStoreIsCleared() { public void when_gettingExperience_expect_experienceStoreIsCleared() {
// Arrange // Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(); StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset(); sut.reset();
sut.addExperience(null, 1, 1.0, false); sut.addExperience(null, 1, 1.0, false);
@ -67,7 +67,7 @@ public class StateActionExperienceHandlerTest {
@Test @Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() { public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange // Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(); StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset(); sut.reset();
sut.addExperience(null, 1, 1.0, false); sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false); sut.addExperience(null, 2, 2.0, false);
@ -79,4 +79,66 @@ public class StateActionExperienceHandlerTest {
// Assert // Assert
assertEquals(3, size); assertEquals(3, size);
} }
@Test
public void when_experienceIsEmpty_expect_TrainingBatchNotReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertFalse(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsGreaterOrEqualToThanBatchSize_expect_TrainingBatchIsReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.addExperience(null, 3, 3.0, false);
sut.addExperience(null, 4, 4.0, false);
sut.addExperience(null, 5, 5.0, false);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertTrue(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsSmallerThanBatchSizeButFinalObservationIsSet_expect_TrainingBatchIsReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.setFinalObservation(null);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertTrue(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsZeroAndFinalObservationIsSet_expect_TrainingBatchIsNotReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.setFinalObservation(null);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertFalse(isTrainingBatchReady);
}
} }
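Read together, the four new tests pin down when StateActionExperienceHandler reports a training batch as ready: once the stored experience reaches the configured batch size, or once the final observation has been set for a non-empty episode. A sketch of that rule (the names below are assumptions for illustration, not the actual RL4J fields):

// Readiness rule implied by the tests above; experienceSize, batchSize and
// isFinalObservationSet are illustrative names.
boolean isTrainingBatchReady(int experienceSize, int batchSize, boolean isFinalObservationSet) {
    return experienceSize >= batchSize || (isFinalObservationSet && experienceSize > 0);
}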
View File
@ -49,4 +49,25 @@ public class INDArrayHelperTest {
assertEquals(1, output.shape()[1]); assertEquals(1, output.shape()[1]);
} }
@Test
public void when_callingCreateBatchForShape_expect_INDArrayWithCorrectShapeAndOriginalShapeUnchanged() {
// Arrange
long[] shape = new long[] { 1, 3, 4};
// Act
INDArray output = INDArrayHelper.createBatchForShape(2, shape);
// Assert
// Output shape
assertEquals(3, output.shape().length);
assertEquals(2, output.shape()[0]);
assertEquals(3, output.shape()[1]);
assertEquals(4, output.shape()[2]);
// Input should remain unchanged
assertEquals(1, shape[0]);
assertEquals(3, shape[1]);
assertEquals(4, shape[2]);
}
} }
View File
@ -19,10 +19,11 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm; import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock; import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner; import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,6 +33,9 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class QLearningUpdateAlgorithmTest { public class QLearningUpdateAlgorithmTest {
@ -39,12 +43,24 @@ public class QLearningUpdateAlgorithmTest {
@Mock @Mock
AsyncGlobal mockAsyncGlobal; AsyncGlobal mockAsyncGlobal;
@Mock
IDQN dqnMock;
private UpdateAlgorithm sut;
private void setup(double gamma) {
// mock a neural net output -- just invert the sign of the input
when(dqnMock.outputAll(any(INDArray.class))).thenAnswer(invocation -> new INDArray[] { invocation.getArgument(0, INDArray.class).mul(-1.0) });
sut = new QLearningUpdateAlgorithm(2, gamma);
}
@Test @Test
public void when_isTerminal_expect_initRewardIs0() { public void when_isTerminal_expect_initRewardIs0() {
// Arrange // Arrange
MockDQN dqnMock = new MockDQN(); setup(1.0);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
final Observation observation = new Observation(Nd4j.zeros(1)); final Observation observation = new Observation(Nd4j.zeros(1, 2));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() { List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{ {
add(new StateActionPair<Integer>(observation, 0, 0.0, true)); add(new StateActionPair<Integer>(observation, 0, 0.0, true));
@ -55,59 +71,68 @@ public class QLearningUpdateAlgorithmTest {
sut.computeGradients(dqnMock, experience); sut.computeGradients(dqnMock, experience);
// Assert // Assert
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 0.0));
} }
@Test @Test
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() { public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
// Arrange // Arrange
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0); setup(1.0);
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }).reshape(1, 2));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() { List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{ {
add(new StateActionPair<Integer>(observation, 0, 0.0, false)); add(new StateActionPair<Integer>(observation, 0, 0.0, false));
} }
}; };
MockDQN dqnMock = new MockDQN();
// Act // Act
sut.computeGradients(dqnMock, experience); sut.computeGradients(dqnMock, experience);
// Assert // Assert
assertEquals(2, dqnMock.outputAllParams.size()); ArgumentCaptor<INDArray> argument = ArgumentCaptor.forClass(INDArray.class);
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001); verify(dqnMock, times(2)).outputAll(argument.capture());
List<INDArray> values = argument.getAllValues();
assertEquals(-123.0, values.get(0).getDouble(0, 0), 0.00001);
assertEquals(-123.0, values.get(1).getDouble(0, 0), 0.00001);
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 234.0));
} }
@Test @Test
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() { public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
// Arrange // Arrange
double gamma = 0.9; double gamma = 0.9;
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma); setup(gamma);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() { List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{ {
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false)); add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 }).reshape(1, 2)), 0, 1.0, false));
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true)); add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 }).reshape(1, 2)), 1, 2.0, true));
} }
}; };
MockDQN dqnMock = new MockDQN();
// Act // Act
sut.computeGradients(dqnMock, experience); sut.computeGradients(dqnMock, experience);
// Assert // Assert
ArgumentCaptor<INDArray> features = ArgumentCaptor.forClass(INDArray.class);
ArgumentCaptor<INDArray> targets = ArgumentCaptor.forClass(INDArray.class);
verify(dqnMock, times(1)).gradient(features.capture(), targets.capture());
// input side -- should be a stack of observations // input side -- should be a stack of observations
INDArray input = dqnMock.gradientParams.get(0).getLeft(); INDArray featuresValues = features.getValue();
assertEquals(-1.1, input.getDouble(0, 0), 0.00001); assertEquals(-1.1, featuresValues.getDouble(0, 0), 0.00001);
assertEquals(-1.2, input.getDouble(0, 1), 0.00001); assertEquals(-1.2, featuresValues.getDouble(0, 1), 0.00001);
assertEquals(-2.1, input.getDouble(1, 0), 0.00001); assertEquals(-2.1, featuresValues.getDouble(1, 0), 0.00001);
assertEquals(-2.2, input.getDouble(1, 1), 0.00001); assertEquals(-2.2, featuresValues.getDouble(1, 1), 0.00001);
// target side // target side
INDArray target = dqnMock.gradientParams.get(0).getRight(); INDArray targetsValues = targets.getValue();
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001); assertEquals(1.0 + gamma * 2.0, targetsValues.getDouble(0, 0), 0.00001);
assertEquals(1.2, target.getDouble(0, 1), 0.00001); assertEquals(1.2, targetsValues.getDouble(0, 1), 0.00001);
assertEquals(2.1, target.getDouble(1, 0), 0.00001); assertEquals(2.1, targetsValues.getDouble(1, 0), 0.00001);
assertEquals(2.0, target.getDouble(1, 1), 0.00001); assertEquals(2.0, targetsValues.getDouble(1, 1), 0.00001);
} }
} }
View File
@ -18,6 +18,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import org.deeplearning4j.gym.StepReply;
+import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
@@ -74,6 +75,9 @@ public class QLearningDiscreteTest {
    @Mock
    QLearningConfiguration mockQlearningConfiguration;

+    @Mock
+    ILearningBehavior<Integer> learningBehavior;

    // HWC
    int[] observationShape = new int[]{3, 10, 10};
    int totalObservationSize = 1;
@@ -92,18 +96,28 @@ public class QLearningDiscreteTest {
    }

-    private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
+    private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay, ILearningBehavior<Integer> learningBehavior) {
        when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
        when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
        when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
        when(mockQlearningConfiguration.getSeed()).thenReturn(123L);

-        qLearningDiscrete = mock(
-                QLearningDiscrete.class,
-                Mockito.withSettings()
-                        .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
-                        .defaultAnswer(Mockito.CALLS_REAL_METHODS)
-        );
+        if(learningBehavior != null) {
+            qLearningDiscrete = mock(
+                    QLearningDiscrete.class,
+                    Mockito.withSettings()
+                            .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0, learningBehavior, Nd4j.getRandom())
+                            .defaultAnswer(Mockito.CALLS_REAL_METHODS)
+            );
+        }
+        else {
+            qLearningDiscrete = mock(
+                    QLearningDiscrete.class,
+                    Mockito.withSettings()
+                            .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
+                            .defaultAnswer(Mockito.CALLS_REAL_METHODS)
+            );
+        }
    }
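The mock(..., withSettings().useConstructor(...).defaultAnswer(CALLS_REAL_METHODS)) construction above is Mockito's idiom for a partial mock: the real constructor runs with the supplied arguments and real methods execute by default, while individual methods can still be stubbed or verified. A minimal sketch of the same pattern, assuming Mockito 2+ and using a made-up Counter class rather than any RL4J type:

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;

import org.junit.Test;
import org.mockito.Mockito;

public class PartialMockSketchTest {

    // Made-up class standing in for QLearningDiscrete: real logic plus an overridable hook
    public static class Counter {
        private final int start;

        public Counter(int start) {
            this.start = start;
        }

        public int next() {
            return start + step();  // real behaviour we want to keep
        }

        protected int step() {
            return 1;               // hook we may want to stub in a test
        }
    }

    @Test
    public void when_stubbingOneMethod_expect_realLogicStillRuns() {
        // The real constructor runs (start = 10); unstubbed methods call the real code
        Counter counter = mock(
                Counter.class,
                Mockito.withSettings()
                        .useConstructor(10)
                        .defaultAnswer(Mockito.CALLS_REAL_METHODS));

        // doReturn() stubs the hook without invoking the real step() during stubbing
        doReturn(5).when(counter).step();

        assertEquals(15, counter.next());
    }
}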
    private void mockHistoryProcessor(int skipFrames) {
@@ -136,7 +150,7 @@ public class QLearningDiscreteTest {
    public void when_singleTrainStep_expect_correctValues() {
        // Arrange
-        mockTestContext(100,0,2,1.0, 10);
+        mockTestContext(100,0,2,1.0, 10, null);

        // An example observation and 2 Q values output (2 actions)
        Observation observation = new Observation(Nd4j.zeros(observationShape));
@@ -162,7 +176,7 @@ public class QLearningDiscreteTest {
    @Test
    public void when_singleTrainStepSkippedFrames_expect_correctValues() {
        // Arrange
-        mockTestContext(100,0,2,1.0, 10);
+        mockTestContext(100,0,2,1.0, 10, learningBehavior);

        Observation skippedObservation = Observation.SkippedObservation;
        Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
@@ -180,8 +194,8 @@ public class QLearningDiscreteTest {
        assertEquals(0, stepReply.getReward(), 1e-5);
        assertFalse(stepReply.isDone());
        assertFalse(stepReply.getObservation().isSkipped());
-        assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
+        verify(learningBehavior, never()).handleNewExperience(any(Observation.class), any(Integer.class), any(Double.class), any(Boolean.class));

        verify(mockDQN, never()).output(any(INDArray.class));
    }
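The never() checks above pin down the skipped-frame contract: on that code path, no experience reaches the learning behaviour and the network is never queried. A small, hypothetical sketch of the same idiom, assuming a Mockito version (3.x) where verifyNoInteractions is available; the Recorder type and class name are made up for illustration:

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;

import org.junit.Test;

public class NeverVerificationSketchTest {

    // Hypothetical collaborator standing in for ILearningBehavior
    interface Recorder {
        void record(String event);
    }

    @Test
    public void when_nothingHappens_expect_collaboratorUntouched() {
        Recorder recorder = mock(Recorder.class);

        // (the code under test would run here and, on this path, skip the recorder)

        // never() asserts that one specific call did not happen...
        verify(recorder, never()).record(anyString());
        // ...while verifyNoInteractions() asserts the mock was not touched at all
        verifyNoInteractions(recorder);
    }
}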