RL4J: Add AgentLearner (#470)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2020-05-27 07:41:02 -04:00 committed by GitHub
parent a18417193d
commit 5568b9d72f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1541 additions and 244 deletions

View File

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import lombok.AccessLevel;
@ -14,7 +29,13 @@ import org.nd4j.common.base.Preconditions;
import java.util.Map;
public class Agent<ACTION> {
/**
* An agent implementation. The Agent will use a {@link IPolicy} to interact with an {@link Environment} and receive
* a reward.
*
* @param <ACTION> The type of action
*/
public class Agent<ACTION> implements IAgent<ACTION> {
@Getter
private final String id;
@ -37,19 +58,28 @@ public class Agent<ACTION> {
private ACTION lastAction;
@Getter
private int episodeStepNumber;
private int episodeStepCount;
@Getter
private double reward;
protected boolean canContinue;
private Agent(Builder<ACTION> builder) {
this.environment = builder.environment;
this.transformProcess = builder.transformProcess;
this.policy = builder.policy;
this.maxEpisodeSteps = builder.maxEpisodeSteps;
this.id = builder.id;
/**
* @param environment The {@link Environment} to be used
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
* @param policy The {@link IPolicy} to be used
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
* @param id A user-supplied id to identify the instance.
*/
public Agent(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id) {
Preconditions.checkArgument(maxEpisodeSteps == null || maxEpisodeSteps > 0, "maxEpisodeSteps must be null (no maximum) or greater than 0, got", maxEpisodeSteps);
this.environment = environment;
this.transformProcess = transformProcess;
this.policy = policy;
this.maxEpisodeSteps = maxEpisodeSteps;
this.id = id;
listeners = buildListenerList();
}
@ -58,10 +88,17 @@ public class Agent<ACTION> {
return new AgentListenerList<ACTION>();
}
/**
* Add a {@link AgentListener} that will be notified when agent events happens
* @param listener
*/
public void addListener(AgentListener listener) {
listeners.add(listener);
}
/**
* This will run a single episode
*/
public void run() {
runEpisode();
}
@ -80,7 +117,7 @@ public class Agent<ACTION> {
canContinue = listeners.notifyBeforeEpisode(this);
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepCount < maxEpisodeSteps)) {
performStep();
}
@ -100,9 +137,9 @@ public class Agent<ACTION> {
}
protected void resetEnvironment() {
episodeStepNumber = 0;
episodeStepCount = 0;
Map<String, Object> channelsData = environment.reset();
this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
this.observation = transformProcess.transform(channelsData, episodeStepCount, false);
}
protected void resetPolicy() {
@ -125,7 +162,6 @@ public class Agent<ACTION> {
}
StepResult stepResult = act(action);
handleStepResult(stepResult);
onAfterStep(stepResult);
@ -134,11 +170,11 @@ public class Agent<ACTION> {
return;
}
incrementEpisodeStepNumber();
incrementEpisodeStepCount();
}
protected void incrementEpisodeStepNumber() {
++episodeStepNumber;
protected void incrementEpisodeStepCount() {
++episodeStepCount;
}
protected ACTION decideAction(Observation observation) {
@ -150,12 +186,15 @@ public class Agent<ACTION> {
}
protected StepResult act(ACTION action) {
return environment.step(action);
}
Observation observationBeforeAction = observation;
protected void handleStepResult(StepResult stepResult) {
observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
reward +=computeReward(stepResult);
StepResult stepResult = environment.step(action);
observation = convertChannelDataToObservation(stepResult, episodeStepCount + 1);
reward += computeReward(stepResult);
onAfterAction(observationBeforeAction, action, stepResult);
return stepResult;
}
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
@ -166,6 +205,10 @@ public class Agent<ACTION> {
return stepResult.getReward();
}
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
// Do Nothing
}
protected void onAfterStep(StepResult stepResult) {
// Do Nothing
}
@ -174,16 +217,24 @@ public class Agent<ACTION> {
// Do Nothing
}
public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
/**
*
* @param environment
* @param transformProcess
* @param policy
* @param <ACTION>
* @return
*/
public static <ACTION> Builder<ACTION, Agent> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
return new Builder<>(environment, transformProcess, policy);
}
public static class Builder<ACTION> {
private final Environment<ACTION> environment;
private final TransformProcess transformProcess;
private final IPolicy<ACTION> policy;
private Integer maxEpisodeSteps = null; // Default, no max
private String id;
public static class Builder<ACTION, AGENT_TYPE extends Agent> {
protected final Environment<ACTION> environment;
protected final TransformProcess transformProcess;
protected final IPolicy<ACTION> policy;
protected Integer maxEpisodeSteps = null; // Default, no max
protected String id;
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
this.environment = environment;
@ -191,20 +242,20 @@ public class Agent<ACTION> {
this.policy = policy;
}
public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
public Builder<ACTION, AGENT_TYPE> maxEpisodeSteps(int maxEpisodeSteps) {
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
this.maxEpisodeSteps = maxEpisodeSteps;
return this;
}
public Builder<ACTION> id(String id) {
public Builder<ACTION, AGENT_TYPE> id(String id) {
this.id = id;
return this;
}
public Agent build() {
return new Agent(this);
public AGENT_TYPE build() {
return (AGENT_TYPE)new Agent<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id);
}
}
}

View File

@ -0,0 +1,115 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
/**
* The ActionLearner is an {@link Agent} that delegate the learning to a {@link ILearningBehavior}.
* @param <ACTION> The type of the action
*/
public class AgentLearner<ACTION> extends Agent<ACTION> implements IAgentLearner<ACTION> {
@Getter
private int totalStepCount = 0;
private final ILearningBehavior<ACTION> learningBehavior;
private double rewardAtLastExperience;
/**
*
* @param environment The {@link Environment} to be used
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
* @param policy The {@link IPolicy} to be used
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
* @param id A user-supplied id to identify the instance.
* @param learningBehavior The {@link ILearningBehavior} that will be used to supervise the learning.
*/
public AgentLearner(Environment<ACTION> environment, TransformProcess transformProcess, IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id, @NonNull ILearningBehavior<ACTION> learningBehavior) {
super(environment, transformProcess, policy, maxEpisodeSteps, id);
this.learningBehavior = learningBehavior;
}
@Override
protected void reset() {
super.reset();
rewardAtLastExperience = 0;
}
@Override
protected void onBeforeEpisode() {
super.onBeforeEpisode();
learningBehavior.handleEpisodeStart();
}
@Override
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
if(!observationBeforeAction.isSkipped()) {
double rewardSinceLastExperience = getReward() - rewardAtLastExperience;
learningBehavior.handleNewExperience(observationBeforeAction, action, rewardSinceLastExperience, stepResult.isTerminal());
rewardAtLastExperience = getReward();
}
}
@Override
protected void onAfterEpisode() {
learningBehavior.handleEpisodeEnd(getObservation());
}
@Override
protected void incrementEpisodeStepCount() {
super.incrementEpisodeStepCount();
++totalStepCount;
}
// FIXME: parent is still visible
public static <ACTION> AgentLearner.Builder<ACTION, AgentLearner<ACTION>> builder(Environment<ACTION> environment,
TransformProcess transformProcess,
IPolicy<ACTION> policy,
ILearningBehavior<ACTION> learningBehavior) {
return new AgentLearner.Builder<ACTION, AgentLearner<ACTION>>(environment, transformProcess, policy, learningBehavior);
}
public static class Builder<ACTION, AGENT_TYPE extends AgentLearner<ACTION>> extends Agent.Builder<ACTION, AGENT_TYPE> {
private final ILearningBehavior<ACTION> learningBehavior;
public Builder(@NonNull Environment<ACTION> environment,
@NonNull TransformProcess transformProcess,
@NonNull IPolicy<ACTION> policy,
@NonNull ILearningBehavior<ACTION> learningBehavior) {
super(environment, transformProcess, policy);
this.learningBehavior = learningBehavior;
}
@Override
public AGENT_TYPE build() {
return (AGENT_TYPE)new AgentLearner<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id, learningBehavior);
}
}
}

View File

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.policy.IPolicy;
/**
* The interface of {@link Agent}
* @param <ACTION>
*/
public interface IAgent<ACTION> {
/**
* Will play a single episode
*/
void run();
/**
* @return A user-supplied id to identify the IAgent instance.
*/
String getId();
/**
* @return The {@link Environment} instance being used by the agent.
*/
Environment<ACTION> getEnvironment();
/**
* @return The {@link IPolicy} instance being used by the agent.
*/
IPolicy<ACTION> getPolicy();
/**
* @return The step count taken in the current episode.
*/
int getEpisodeStepCount();
/**
* @return The cumulative reward received in the current episode.
*/
double getReward();
}

View File

@ -0,0 +1,24 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent;
public interface IAgentLearner<ACTION> extends IAgent<ACTION> {
/**
* @return The total count of steps taken by this AgentLearner, for all episodes.
*/
int getTotalStepCount();
}

View File

@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.learning;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* The <code>ILearningBehavior</code> implementations are in charge of the training. Through this interface, they are
* notified as new experience is generated.
*
* @param <ACTION> The type of action
*/
public interface ILearningBehavior<ACTION> {
/**
* This method is called when a new episode has been started.
*/
void handleEpisodeStart();
/**
* This method is called when new experience is generated.
*
* @param observation The observation prior to taking the action
* @param action The action that has been taken
* @param reward The reward received by taking the action
* @param isTerminal True if the episode ended after taking the action
*/
void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal);
/**
* This method is called when the episode ends or the maximum number of episode steps is reached.
*
* @param finalObservation The observation after the last action of the episode has been taken.
*/
void handleEpisodeEnd(Observation finalObservation);
}

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.learning;
import lombok.Builder;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* A generic {@link ILearningBehavior} that delegates the handling of experience to a {@link ExperienceHandler} and
* the update logic to a {@link IUpdateRule}
*
* @param <ACTION> The type of the action
* @param <EXPERIENCE_TYPE> The type of experience the ExperienceHandler needs
*/
@Builder
public class LearningBehavior<ACTION, EXPERIENCE_TYPE> implements ILearningBehavior<ACTION> {
@Builder.Default
private int experienceUpdateSize = 64;
private final ExperienceHandler<ACTION, EXPERIENCE_TYPE> experienceHandler;
private final IUpdateRule<EXPERIENCE_TYPE> updateRule;
@Override
public void handleEpisodeStart() {
experienceHandler.reset();
}
@Override
public void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal) {
experienceHandler.addExperience(observation, action, reward, isTerminal);
if(experienceHandler.isTrainingBatchReady()) {
updateRule.update(experienceHandler.generateTrainingBatch());
}
}
@Override
public void handleEpisodeEnd(Observation finalObservation) {
experienceHandler.setFinalObservation(finalObservation);
if(experienceHandler.isTrainingBatchReady()) {
updateRule.update(experienceHandler.generateTrainingBatch());
}
}
}

View File

@ -1,23 +1,66 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.listener;
import org.deeplearning4j.rl4j.agent.Agent;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* The base definition of all {@link Agent} event listeners
*/
public interface AgentListener<ACTION> {
enum ListenerResponse {
/**
* Tell the learning process to continue calling the listeners and the training.
* Tell the {@link Agent} to continue calling the listeners and the processing.
*/
CONTINUE,
/**
* Tell the learning process to stop calling the listeners and terminate the training.
* Tell the {@link Agent} to interrupt calling the listeners and stop the processing.
*/
STOP,
}
/**
* Called when a new episode is about to start.
* @param agent The agent that generated the event
*
* @return A {@link ListenerResponse}.
*/
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
/**
* Called when a step is about to be taken.
*
* @param agent The agent that generated the event
* @param observation The observation before the action is taken
* @param action The action that will be performed
*
* @return A {@link ListenerResponse}.
*/
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
/**
* Called after a step has been taken.
*
* @param agent The agent that generated the event
* @param stepResult The {@link StepResult} result of the step.
*
* @return A {@link ListenerResponse}.
*/
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
}

View File

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.listener;
import org.deeplearning4j.rl4j.agent.Agent;
@ -7,6 +22,10 @@ import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
/**
* A class that manages a list of {@link AgentListener AgentListeners} listening to an {@link Agent}.
* @param <ACTION>
*/
public class AgentListenerList<ACTION> {
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
@ -18,6 +37,13 @@ public class AgentListenerList<ACTION> {
listeners.add(listener);
}
/**
* This method will notify all listeners that an episode is about to start. If a listener returns
* {@link AgentListener.ListenerResponse STOP}, any following listener is skipped.
*
* @param agent The agent that generated the event.
* @return False if the processing should be stopped
*/
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
@ -28,6 +54,13 @@ public class AgentListenerList<ACTION> {
return true;
}
/**
*
* @param agent The agent that generated the event.
* @param observation The observation before the action is taken
* @param action The action that will be performed
* @return False if the processing should be stopped
*/
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
@ -38,6 +71,12 @@ public class AgentListenerList<ACTION> {
return true;
}
/**
*
* @param agent The agent that generated the event.
* @param stepResult The {@link StepResult} result of the step.
* @return False if the processing should be stopped
*/
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
for (AgentListener<ACTION> listener : listeners) {
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.List;
// Temporary class that will be replaced with a more generic class that delegates gradient computation
// and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
@Getter
private final IDQN qNetwork;
@Getter
private IDQN targetQNetwork;
private final int targetUpdateFrequency;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@Getter
private int updateCount = 0;
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
this.qNetwork = qNetwork;
this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(this, gamma, errorClamp)
: new StandardDQN(this, gamma, errorClamp);
}
@Override
public void update(List<Transition<Integer>> trainingBatch) {
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels());
if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork = qNetwork.clone();
}
}
}

View File

@ -0,0 +1,26 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import lombok.Value;
import org.deeplearning4j.nn.gradient.Gradient;
// Work in progress
@Value
public class Gradients {
private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[]
private int batchSize;
}

View File

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update;
import java.util.List;
/**
* The role of IUpdateRule implementations is to use an experience batch to improve the accuracy of the policy.
* Used by {@link org.deeplearning4j.rl4j.agent.AgentLearner AgentLearner}
* @param <EXPERIENCE_TYPE> The type of the experience
*/
public interface IUpdateRule<EXPERIENCE_TYPE> {
/**
* Perform the update
* @param trainingBatch A batch of experience
*/
void update(List<EXPERIENCE_TYPE> trainingBatch);
/**
* @return The total number of times the policy has been updated. In a multi-agent learning context, this total is
* for all the agents.
*/
int getUpdateCount();
}

View File

@ -1,9 +0,0 @@
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
@Value
public class ActionSchema<ACTION> {
private ACTION noOp;
//FIXME ACTION randomAction();
}

View File

@ -1,11 +1,54 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import java.util.Map;
/**
* An interface for environments used by the {@link org.deeplearning4j.rl4j.agent.Agent Agents}.
* @param <ACTION> The type of actions
*/
public interface Environment<ACTION> {
/**
* @return The {@link Schema} of the environment
*/
Schema<ACTION> getSchema();
/**
* Reset the environment's state to start a new episode.
* @return
*/
Map<String, Object> reset();
/**
* Perform a single step.
*
* @param action The action taken
* @return A {@link StepResult} describing the result of the step.
*/
StepResult step(ACTION action);
/**
* @return True if the episode is finished
*/
boolean isEpisodeFinished();
/**
* Called when the agent is finished using this environment instance.
*/
void close();
}

View File

@ -0,0 +1,26 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
// Work in progress
public interface IActionSchema<ACTION> {
ACTION getNoOp();
// Review: A schema should be data-only and not have behavior
ACTION getRandomAction();
}

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
// Work in progress
public class IntegerActionSchema implements IActionSchema<Integer> {
private final int numActions;
private final int noOpAction;
private final Random rnd;
public IntegerActionSchema(int numActions, int noOpAction) {
this(numActions, noOpAction, Nd4j.getRandom());
}
public IntegerActionSchema(int numActions, int noOpAction, Random rnd) {
this.numActions = numActions;
this.noOpAction = noOpAction;
this.rnd = rnd;
}
@Override
public Integer getNoOp() {
return noOpAction;
}
@Override
public Integer getRandomAction() {
return rnd.nextInt(numActions);
}
}

View File

@ -1,8 +1,24 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import lombok.Value;
// Work in progress
@Value
public class Schema<ACTION> {
private ActionSchema<ACTION> actionSchema;
private IActionSchema<ACTION> actionSchema;
}

View File

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.environment;
import lombok.Value;

View File

@ -41,6 +41,11 @@ public interface ExperienceHandler<A, E> {
*/
int getTrainingBatchSize();
/**
* @return True if a batch is ready for training.
*/
boolean isTrainingBatchReady();
/**
* The elements are returned in the historical order (i.e. in the order they happened)
* @return The list of experience elements

View File

@ -36,6 +36,7 @@ import java.util.List;
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
private static final int DEFAULT_BATCH_SIZE = 32;
private final int batchSize;
private IExpReplay<A> expReplay;
@ -43,6 +44,7 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
this.expReplay = expReplay;
this.batchSize = expReplay.getDesignatedBatchSize();
}
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
@ -64,6 +66,11 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
return expReplay.getBatchSize();
}
@Override
public boolean isTrainingBatchReady() {
return expReplay.getBatchSize() >= batchSize;
}
/**
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
*/

View File

@ -30,10 +30,18 @@ import java.util.List;
*/
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
private final int batchSize;
private boolean isFinalObservationSet;
public StateActionExperienceHandler(int batchSize) {
this.batchSize = batchSize;
}
private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
public void setFinalObservation(Observation observation) {
// Do nothing
isFinalObservationSet = true;
}
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
@ -45,6 +53,12 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
return stateActionPairs.size();
}
@Override
public boolean isTrainingBatchReady() {
return stateActionPairs.size() >= batchSize
|| (isFinalObservationSet && stateActionPairs.size() > 0);
}
/**
* The elements are returned in the historical order (i.e. in the order they happened)
* Note: the experience store is cleared after calling this method.
@ -62,6 +76,7 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
@Override
public void reset() {
stateActionPairs = new ArrayList<>();
isFinalObservationSet = false;
}
}

View File

@ -24,17 +24,38 @@ import org.nd4j.linalg.factory.Nd4j;
* @author Alexandre Boulanger
*/
public class INDArrayHelper {
/**
* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
*
* We must have either shape 2 (NK) or shape 4 (NCHW)
* Force the input source to have the correct shape:
* <p><ul>
* <li>DL4J requires it to be at least 2D</li>
* <li>RL4J has a convention to have the batch size on dimension 0 to all INDArrays</li>
* </ul></p>
* @param source The {@link INDArray} to be corrected.
* @return The corrected INDArray
*/
public static INDArray forceCorrectShape(INDArray source) {
return source.shape()[0] == 1 && source.shape().length > 1
return source.shape()[0] == 1 && source.rank() > 1
? source
: Nd4j.expandDims(source, 0);
}
/**
* This will create a INDArray with <i>batchSize</i> as dimension 0 and <i>shape</i> as other dimensions.
* For example, if <i>batchSize</i> is 10 and shape is { 1, 3, 4 }, the resulting INDArray shape will be { 10, 3, 4}
* @param batchSize The size of the batch to create
* @param shape The shape of individual elements.
* Note: all shapes in RL4J should have a batch size as dimension 0; in this case the batch size should be 1.
* @return A INDArray
*/
public static INDArray createBatchForShape(long batchSize, long... shape) {
long[] batchShape;
batchShape = new long[shape.length];
System.arraycopy(shape, 0, batchShape, 0, shape.length);
batchShape[0] = batchSize;
return Nd4j.create(batchShape);
}
}

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
@ -49,7 +50,7 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
// TODO: Make it configurable with a builder
@Setter(AccessLevel.PROTECTED) @Getter
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
private ExperienceHandler experienceHandler;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
@ -60,6 +61,17 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
synchronized (asyncGlobal) {
current = (NN) asyncGlobal.getTarget().clone();
}
experienceHandler = new StateActionExperienceHandler(getNStep());
}
private int getNStep() {
IAsyncLearningConfiguration configuration = getConfiguration();
if(configuration == null) {
return Integer.MAX_VALUE;
}
return configuration.getNStep();
}
// TODO: Add an actor-learner class and be able to inject the update algorithm

View File

@ -71,7 +71,6 @@ public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> ex
@Override
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
return new QLearningUpdateAlgorithm(getMdp().getActionSpace().getSize(), configuration.getGamma());
}
}

View File

@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -27,15 +27,12 @@ import java.util.List;
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
private final int[] shape;
private final int actionSpaceSize;
private final double gamma;
public QLearningUpdateAlgorithm(int[] shape,
int actionSpaceSize,
public QLearningUpdateAlgorithm(int actionSpaceSize,
double gamma) {
this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.gamma = gamma;
}
@ -44,33 +41,34 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
int size = experience.size();
int[] nshape = Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = Nd4j.create(size, actionSpaceSize);
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
INDArray data = stateActionPair.getObservation().getData();
INDArray features = INDArrayHelper.createBatchForShape(size, data.shape());
INDArray targets = Nd4j.create(size, actionSpaceSize);
double r;
if (stateActionPair.isTerminal()) {
r = 0;
} else {
INDArray[] output = null;
output = current.outputAll(stateActionPair.getObservation().getData());
output = current.outputAll(data);
r = Nd4j.max(output[0]).getDouble(0);
}
for (int i = size - 1; i >= 0; i--) {
stateActionPair = experience.get(i);
data = stateActionPair.getObservation().getData();
input.putRow(i, stateActionPair.getObservation().getData());
features.putRow(i, data);
r = stateActionPair.getReward() + gamma * r;
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
INDArray[] output = current.outputAll(data);
INDArray row = output[0];
row = row.putScalar(stateActionPair.getAction(), r);
targets.putRow(i, row);
}
return current.gradient(input, targets);
return current.gradient(features, targets);
}
}

View File

@ -80,6 +80,11 @@ public class ExpReplay<A> implements IExpReplay<A> {
//log.info("size: "+storage.size());
}
@Override
public int getDesignatedBatchSize() {
return batchSize;
}
public int getBatchSize() {
int storageSize = storage.size();
return Math.min(storageSize, batchSize);

View File

@ -47,4 +47,9 @@ public interface IExpReplay<A> {
* @param transition a new transition to store
*/
void store(Transition<A> transition);
/**
* @return The desired size of batches
*/
int getDesignatedBatchSize();
}

View File

@ -51,25 +51,16 @@ import java.util.List;
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN>
implements TargetQNetworkSource, IEpochTrainer {
implements IEpochTrainer {
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
protected abstract EpsGreedy<A> getEgPolicy();
public abstract MDP<O, A, AS> getMdp();
public abstract IDQN getQNetwork();
public abstract IDQN getTargetQNetwork();
protected abstract void setTargetQNetwork(IDQN dqn);
protected void updateTargetNetwork() {
log.info("Update target network");
setTargetQNetwork(getQNetwork().clone());
}
public IDQN getNeuralNet() {
return getQNetwork();
}
@ -101,11 +92,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
int numQ = 0;
List<Double> scores = new ArrayList<>();
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
updateTargetNetwork();
}
QLStepReturn<Observation> stepR = trainStep(obs);
if (!stepR.getMaxQ().isNaN()) {
@ -146,7 +132,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected void resetNetworks() {
getQNetwork().reset();
getTargetQNetwork().reset();
}
private InitMdp<Observation> refacInitMdp() {

View File

@ -21,6 +21,10 @@ import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
import org.deeplearning4j.rl4j.agent.update.DQNNeuralNetUpdateRule;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
@ -28,9 +32,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.Encodable;
@ -41,12 +42,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@ -63,22 +60,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter
private DQNPolicy<O> policy;
@Getter
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
private EpsGreedy<Integer> egPolicy;
@Getter
final private IDQN qNetwork;
@Getter
@Setter(AccessLevel.PROTECTED)
private IDQN targetQNetwork;
private int lastAction;
private double accuReward = 0;
ITDTargetAlgorithm tdTargetAlgorithm;
// TODO: User a builder and remove the setter
@Getter(AccessLevel.PROTECTED) @Setter
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
private final ILearningBehavior<Integer> learningBehavior;
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
return mdp;
@ -88,21 +78,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) {
this(mdp, dqn, conf, epsilonNbStep, buildLearningBehavior(dqn, conf, random), random);
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) {
int epsilonNbStep, ILearningBehavior<Integer> learningBehavior, Random random) {
this.configuration = conf;
this.mdp = new LegacyMDPWrapper<>(mdp, null);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
this);
tdTargetAlgorithm = conf.isDoubleDQN()
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
this.learningBehavior = learningBehavior;
}
private static ILearningBehavior<Integer> buildLearningBehavior(IDQN qNetwork, QLearningConfiguration conf, Random random) {
IUpdateRule<Transition<Integer>> updateRule = new DQNNeuralNetUpdateRule(qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp());
ExperienceHandler<Integer, Transition<Integer>> experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
return LearningBehavior.<Integer, Transition<Integer>>builder()
.experienceHandler(experienceHandler)
.updateRule(updateRule)
.experienceUpdateSize(conf.getBatchSize())
.build();
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
public MDP<O, Integer, DiscreteSpace> getMdp() {
@ -119,7 +119,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public void preEpoch() {
lastAction = mdp.getActionSpace().noOp();
accuReward = 0;
experienceHandler.reset();
learningBehavior.handleEpisodeStart();
}
@Override
@ -136,12 +136,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
*/
protected QLStepReturn<Observation> trainStep(Observation obs) {
boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = this.getConfiguration().getUpdateStart()
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
Double maxQ = Double.NaN; //ignore if Nan for stats
//if step of training, just repeat lastAction
@ -160,29 +154,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (!obs.isSkipped()) {
// Add experience
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
learningBehavior.handleNewExperience(obs, lastAction, accuReward, stepReply.isDone());
accuReward = 0;
// Update NN
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (this.getStepCount() > updateStart) {
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
}
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
}
protected DataSet setTarget(List<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
return tdTargetAlgorithm.computeTDTargets(transitions);
}
@Override
protected void finishEpoch(Observation observation) {
experienceHandler.setFinalObservation(observation);
learningBehavior.handleEpisodeEnd(observation);
}
}

View File

@ -2,21 +2,19 @@ package org.deeplearning4j.rl4j.mdp;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.environment.*;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
public class CartpoleEnvironment implements Environment<Integer> {
private static final int NUM_ACTIONS = 2;
private static final int ACTION_LEFT = 0;
private static final int ACTION_RIGHT = 1;
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
private final Schema<Integer> schema;
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
@ -48,11 +46,12 @@ public class CartpoleEnvironment implements Environment<Integer> {
private Integer stepsBeyondDone;
public CartpoleEnvironment() {
rnd = new Random();
this(Nd4j.getRandom());
}
public CartpoleEnvironment(int seed) {
rnd = new Random(seed);
public CartpoleEnvironment(Random rnd) {
this.rnd = rnd;
this.schema = new Schema<Integer>(new IntegerActionSchema(NUM_ACTIONS, ACTION_LEFT, rnd));
}
@Override

View File

@ -17,16 +17,19 @@
package org.deeplearning4j.rl4j.policy;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.environment.IActionSchema;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
@ -38,18 +41,60 @@ import org.nd4j.linalg.api.rng.Random;
* epislon is annealed to minEpsilon over epsilonNbStep steps
*
*/
@AllArgsConstructor
@Slf4j
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
public class EpsGreedy<A> extends Policy<A> {
final private Policy<A> policy;
final private MDP<OBSERVATION, A, AS> mdp;
final private INeuralNetPolicy<A> policy;
final private int updateStart;
final private int epsilonNbStep;
final private Random rnd;
final private double minEpsilon;
private final IActionSchema<A> actionSchema;
final private MDP<Encodable, A, ActionSpace<A>> mdp;
final private IEpochTrainer learning;
// Using agent's (learning's) step count is incorrect; frame skipping makes epsilon's value decrease too quickly
private int annealingStep = 0;
@Deprecated
public <OBSERVATION extends Encodable, AS extends ActionSpace<A>> EpsGreedy(Policy<A> policy,
MDP<Encodable, A, ActionSpace<A>> mdp,
int updateStart,
int epsilonNbStep,
Random rnd,
double minEpsilon,
IEpochTrainer learning) {
this.policy = policy;
this.mdp = mdp;
this.updateStart = updateStart;
this.epsilonNbStep = epsilonNbStep;
this.rnd = rnd;
this.minEpsilon = minEpsilon;
this.learning = learning;
this.actionSchema = null;
}
public EpsGreedy(@NonNull Policy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep) {
this(policy, actionSchema, minEpsilon, updateStart, epsilonNbStep, null);
}
@Builder
public EpsGreedy(@NonNull INeuralNetPolicy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep, Random rnd) {
this.policy = policy;
this.rnd = rnd == null ? Nd4j.getRandom() : rnd;
this.minEpsilon = minEpsilon;
this.updateStart = updateStart;
this.epsilonNbStep = epsilonNbStep;
this.actionSchema = actionSchema;
this.mdp = null;
this.learning = null;
}
public NeuralNet getNeuralNet() {
return policy.getNeuralNet();
}
@ -57,6 +102,11 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
public A nextAction(INDArray input) {
double ep = getEpsilon();
if(actionSchema != null) {
// Only legacy classes should pass here.
throw new RuntimeException("nextAction(Observation observation) should be called when using a AgentLearner");
}
if (learning.getStepCount() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCount());
if (rnd.nextDouble() > ep)
@ -66,10 +116,31 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
}
public A nextAction(Observation observation) {
if(actionSchema == null) {
return this.nextAction(observation.getData());
}
A result;
double ep = getEpsilon();
if (annealingStep % 500 == 1) {
log.info("EP: " + ep + " " + annealingStep);
}
if (rnd.nextDouble() > ep) {
result = policy.nextAction(observation);
}
else {
result = actionSchema.getRandomAction();
}
++annealingStep;
return result;
}
public double getEpsilon() {
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
int step = actionSchema != null ? annealingStep : learning.getStepCount();
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep));
}
}

View File

@ -0,0 +1,7 @@
package org.deeplearning4j.rl4j.policy;
import org.deeplearning4j.rl4j.network.NeuralNet;
public interface INeuralNetPolicy<ACTION> extends IPolicy<ACTION> {
NeuralNet getNeuralNet();
}

View File

@ -34,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
*
* A Policy responsability is to choose the next action given a state
*/
public abstract class Policy<A> implements IPolicy<A> {
public abstract class Policy<A> implements INeuralNetPolicy<A> {
public abstract NeuralNet getNeuralNet();

View File

@ -0,0 +1,211 @@
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.IntegerActionSchema;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
import static org.junit.Assert.*;
@RunWith(MockitoJUnitRunner.class)
public class AgentLearnerTest {
@Mock
Environment<Integer> environmentMock;
@Mock
TransformProcess transformProcessMock;
@Mock
IPolicy<Integer> policyMock;
@Mock
LearningBehavior<Integer, Object> learningBehaviorMock;
@Test
public void when_episodeIsStarted_expect_learningBehaviorHandleEpisodeStartCalled() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(3)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
// Act
sut.run();
// Assert
verify(learningBehaviorMock, times(1)).handleEpisodeStart();
}
@Test
public void when_runIsCalled_expect_experienceHandledWithLearningBehavior() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
ArgumentCaptor<Boolean> isTerminalCaptor = ArgumentCaptor.forClass(Boolean.class);
verify(learningBehaviorMock, times(2)).handleNewExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminalCaptor.capture());
List<Observation> observations = observationCaptor.getAllValues();
List<Integer> actions = actionCaptor.getAllValues();
List<Double> rewards = rewardCaptor.getAllValues();
List<Boolean> isTerminalList = isTerminalCaptor.getAllValues();
assertEquals(0.0, observations.get(0).getData().getDouble(0), 0.00001);
assertEquals(0, (int)actions.get(0));
assertEquals(0.0 + 1.0, rewards.get(0), 0.00001);
assertFalse(isTerminalList.get(0));
assertEquals(2.2, observations.get(1).getData().getDouble(0), 0.00001);
assertEquals(2, (int)actions.get(1));
assertEquals(2.0 + 3.0, rewards.get(1), 0.00001);
assertFalse(isTerminalList.get(1));
ArgumentCaptor<Observation> finalObservationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(learningBehaviorMock, times(1)).handleEpisodeEnd(finalObservationCaptor.capture());
assertEquals(4.4, finalObservationCaptor.getValue().getData().getDouble(0), 0.00001);
}
@Test
public void when_runIsCalledMultipleTimes_expect_totalStepCountCorrect() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
reward[0] = 0.0;
sut.run();
// Assert
assertEquals(8, sut.getTotalStepCount());
}
@Test
public void when_runIsCalledMultipleTimes_expect_rewardSentToLearningBehaviorToBeCorrect() {
// Arrange
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
.maxEpisodeSteps(4)
.build();
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.reset()).thenReturn(new HashMap<>());
double[] reward = new double[] { 0.0 };
when(environmentMock.step(any(Integer.class)))
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
.thenAnswer(new Answer<Observation>() {
public Observation answer(InvocationOnMock invocation) throws Throwable {
int step = (int)invocation.getArgument(1);
boolean isTerminal = (boolean)invocation.getArgument(2);
return (step % 2 == 0 || isTerminal)
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
: Observation.SkippedObservation;
}
});
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
// Act
sut.run();
reward[0] = 0.0;
sut.run();
// Assert
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
verify(learningBehaviorMock, times(4)).handleNewExperience(any(Observation.class), any(Integer.class), rewardCaptor.capture(), any(Boolean.class));
List<Double> rewards = rewardCaptor.getAllValues();
// rewardAtLastExperience at the end of 1st call to .run() should not leak into 2nd call.
assertEquals(0.0 + 1.0, rewards.get(2), 0.00001);
assertEquals(2.0 + 3.0, rewards.get(3), 0.00001);
}
}

View File

@ -1,10 +1,7 @@
package org.deeplearning4j.rl4j.agent;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.environment.ActionSchema;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.Schema;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.environment.*;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
@ -12,6 +9,7 @@ import org.junit.Rule;
import org.junit.Test;
import static org.junit.Assert.*;
import org.junit.runner.RunWith;
import org.mockito.*;
import org.mockito.junit.*;
import org.nd4j.linalg.factory.Nd4j;
@ -23,8 +21,8 @@ import java.util.Map;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class AgentTest {
@Mock Environment environmentMock;
@Mock TransformProcess transformProcessMock;
@Mock IPolicy policyMock;
@ -102,7 +100,7 @@ public class AgentTest {
public void when_runIsCalled_expect_agentIsReset() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -119,7 +117,7 @@ public class AgentTest {
sut.run();
// Assert
assertEquals(0, sut.getEpisodeStepNumber());
assertEquals(0, sut.getEpisodeStepCount());
verify(transformProcessMock).transform(envResetResult, 0, false);
verify(policyMock, times(1)).reset();
assertEquals(0.0, sut.getReward(), 0.00001);
@ -130,7 +128,7 @@ public class AgentTest {
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -152,7 +150,7 @@ public class AgentTest {
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -179,7 +177,7 @@ public class AgentTest {
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -191,10 +189,10 @@ public class AgentTest {
final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
((Agent)invocation.getMock()).incrementEpisodeStepCount();
return null;
}).when(spy).performStep();
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepCount() >= 5 );
// Act
spy.run();
@ -209,7 +207,7 @@ public class AgentTest {
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -222,7 +220,7 @@ public class AgentTest {
final Agent spy = Mockito.spy(sut);
doAnswer(invocation -> {
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
((Agent)invocation.getMock()).incrementEpisodeStepCount();
return null;
}).when(spy).performStep();
@ -239,7 +237,7 @@ public class AgentTest {
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -264,7 +262,7 @@ public class AgentTest {
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.getSchema()).thenReturn(schema);
@ -289,7 +287,7 @@ public class AgentTest {
public void when_observationsIsSkipped_expect_performLastAction() {
// Arrange
Map<String, Object> envResetResult = new HashMap<>();
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(envResetResult);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
when(environmentMock.getSchema()).thenReturn(schema);
@ -331,7 +329,7 @@ public class AgentTest {
@Test
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
@ -358,7 +356,7 @@ public class AgentTest {
@Test
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
@ -381,7 +379,7 @@ public class AgentTest {
@Test
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
@ -405,7 +403,7 @@ public class AgentTest {
@Test
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
@ -430,7 +428,7 @@ public class AgentTest {
@Test
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
@ -458,7 +456,7 @@ public class AgentTest {
@Test
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
// Arrange
Schema schema = new Schema(new ActionSchema<>(-1));
Schema schema = new Schema(new IntegerActionSchema(0, -1));
when(environmentMock.reset()).thenReturn(new HashMap<>());
when(environmentMock.getSchema()).thenReturn(schema);
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);

View File

@ -0,0 +1,133 @@
package org.deeplearning4j.rl4j.agent.learning;
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class LearningBehaviorTest {
@Mock
ExperienceHandler<Integer, Object> experienceHandlerMock;
@Mock
IUpdateRule<Object> updateRuleMock;
LearningBehavior<Integer, Object> sut;
@Before
public void setup() {
sut = LearningBehavior.<Integer, Object>builder()
.experienceHandler(experienceHandlerMock)
.updateRule(updateRuleMock)
.build();
}
@Test
public void when_callingHandleEpisodeStart_expect_experienceHandlerResetCalled() {
// Arrange
LearningBehavior<Integer, Object> sut = LearningBehavior.<Integer, Object>builder()
.experienceHandler(experienceHandlerMock)
.updateRule(updateRuleMock)
.build();
// Act
sut.handleEpisodeStart();
// Assert
verify(experienceHandlerMock, times(1)).reset();
}
@Test
public void when_callingHandleNewExperience_expect_experienceHandlerAddExperienceCalled() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
// Act
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
ArgumentCaptor<Boolean> isTerminatedCaptor = ArgumentCaptor.forClass(Boolean.class);
verify(experienceHandlerMock, times(1)).addExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminatedCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
assertEquals(1, (int)actionCaptor.getValue());
assertEquals(2.0, (double)rewardCaptor.getValue(), 0.00001);
assertFalse(isTerminatedCaptor.getValue());
verify(updateRuleMock, never()).update(any(List.class));
}
@Test
public void when_callingHandleNewExperienceAndTrainingBatchIsReady_expect_updateRuleUpdateWithTrainingBatch() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
List<Object> trainingBatch = new ArrayList<Object>();
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
// Act
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
// Assert
verify(updateRuleMock, times(1)).update(trainingBatch);
}
@Test
public void when_callingHandleEpisodeEnd_expect_experienceHandlerSetFinalObservationCalled() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
// Act
sut.handleEpisodeEnd(new Observation(observationData));
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
verify(updateRuleMock, never()).update(any(List.class));
}
@Test
public void when_callingHandleEpisodeEndAndTrainingBatchIsNotEmpty_expect_updateRuleUpdateWithTrainingBatch() {
// Arrange
INDArray observationData = Nd4j.rand(1, 1);
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
List<Object> trainingBatch = new ArrayList<Object>();
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
// Act
sut.handleEpisodeEnd(new Observation(observationData));
// Assert
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
verify(updateRuleMock, times(1)).update(trainingBatch);
}
}

View File

@ -4,34 +4,44 @@ import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class ReplayMemoryExperienceHandlerTest {
@Mock
IExpReplay<Integer> expReplayMock;
@Test
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
when(expReplayMock.getDesignatedBatchSize()).thenReturn(10);
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
int numStoredTransitions = expReplayMock.addedTransitions.size();
boolean isStoreCalledAfterFirstAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
boolean isStoreCalledAfterSecondAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
// Assert
assertEquals(0, numStoredTransitions);
assertEquals(1, expReplayMock.addedTransitions.size());
assertFalse(isStoreCalledAfterFirstAdd);
assertTrue(isStoreCalledAfterSecondAdd);
}
@Test
public void when_addingExperience_expect_transitionsAreCorrect() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
@ -40,24 +50,25 @@ public class ReplayMemoryExperienceHandlerTest {
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Assert
assertEquals(2, expReplayMock.addedTransitions.size());
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
verify(expReplayMock, times(2)).store(argument.capture());
List<Transition<Integer>> transitions = argument.getAllValues();
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
assertEquals(1.0, transitions.get(0).getObservation().getData().getDouble(0), 0.00001);
assertEquals(1, (int)transitions.get(0).getAction());
assertEquals(1.0, transitions.get(0).getReward(), 0.00001);
assertEquals(2.0, transitions.get(0).getNextObservation().getDouble(0), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
assertEquals(2.0, transitions.get(1).getObservation().getData().getDouble(0), 0.00001);
assertEquals(2, (int)transitions.get(1).getAction());
assertEquals(2.0, transitions.get(1).getReward(), 0.00001);
assertEquals(3.0, transitions.get(1).getNextObservation().getDouble(0), 0.00001);
}
@Test
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
@ -66,42 +77,57 @@ public class ReplayMemoryExperienceHandlerTest {
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
// Assert
assertEquals(1, expReplayMock.addedTransitions.size());
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
verify(expReplayMock, times(1)).store(argument.capture());
Transition<Integer> transition = argument.getValue();
assertEquals(1, (int)transition.getAction());
}
@Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Act
int size = sut.getTrainingBatchSize();
// Assert
assertEquals(2, size);
}
private static class TestExpReplay implements IExpReplay<Integer> {
@Test
public void when_experienceSizeIsSmallerThanBatchSize_expect_TrainingBatchIsNotReady() {
// Arrange
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
// Act
@Override
public ArrayList<Transition<Integer>> getBatch() {
return null;
// Assert
assertFalse(sut.isTrainingBatchReady());
}
@Override
public void store(Transition<Integer> transition) {
addedTransitions.add(transition);
@Test
public void when_experienceSizeIsGreaterOrEqualToBatchSize_expect_TrainingBatchIsReady() {
// Arrange
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 4.0 })), 4, 4.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 5.0 })), 5, 5.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 6.0 })));
// Act
// Assert
assertTrue(sut.isTrainingBatchReady());
}
@Override
public int getBatchSize() {
return addedTransitions.size();
}
}
}

View File

@ -13,7 +13,7 @@ public class StateActionExperienceHandlerTest {
@Test
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset();
Observation observation = new Observation(Nd4j.zeros(1));
sut.addExperience(observation, 123, 234.0, true);
@ -32,7 +32,7 @@ public class StateActionExperienceHandlerTest {
@Test
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
@ -51,7 +51,7 @@ public class StateActionExperienceHandlerTest {
@Test
public void when_gettingExperience_expect_experienceStoreIsCleared() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
@ -67,7 +67,7 @@ public class StateActionExperienceHandlerTest {
@Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
@ -79,4 +79,66 @@ public class StateActionExperienceHandlerTest {
// Assert
assertEquals(3, size);
}
@Test
public void when_experienceIsEmpty_expect_TrainingBatchNotReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertFalse(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsGreaterOrEqualToThanBatchSize_expect_TrainingBatchIsReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.addExperience(null, 3, 3.0, false);
sut.addExperience(null, 4, 4.0, false);
sut.addExperience(null, 5, 5.0, false);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertTrue(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsSmallerThanBatchSizeButFinalObservationIsSet_expect_TrainingBatchIsReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.setFinalObservation(null);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertTrue(isTrainingBatchReady);
}
@Test
public void when_experienceSizeIsZeroAndFinalObservationIsSet_expect_TrainingBatchIsNotReady() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
sut.reset();
sut.setFinalObservation(null);
// Act
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
// Assert
assertFalse(isTrainingBatchReady);
}
}

View File

@ -49,4 +49,25 @@ public class INDArrayHelperTest {
assertEquals(1, output.shape()[1]);
}
@Test
public void when_callingCreateBatchForShape_expect_INDArrayWithCorrectShapeAndOriginalShapeUnchanged() {
// Arrange
long[] shape = new long[] { 1, 3, 4};
// Act
INDArray output = INDArrayHelper.createBatchForShape(2, shape);
// Assert
// Output shape
assertEquals(3, output.shape().length);
assertEquals(2, output.shape()[0]);
assertEquals(3, output.shape()[1]);
assertEquals(4, output.shape()[2]);
// Input should remain unchanged
assertEquals(1, shape[0]);
assertEquals(3, shape[1]);
assertEquals(4, shape[2]);
}
}

View File

@ -19,10 +19,11 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,6 +33,9 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class QLearningUpdateAlgorithmTest {
@ -39,12 +43,24 @@ public class QLearningUpdateAlgorithmTest {
@Mock
AsyncGlobal mockAsyncGlobal;
@Mock
IDQN dqnMock;
private UpdateAlgorithm sut;
private void setup(double gamma) {
// mock a neural net output -- just invert the sign of the input
when(dqnMock.outputAll(any(INDArray.class))).thenAnswer(invocation -> new INDArray[] { invocation.getArgument(0, INDArray.class).mul(-1.0) });
sut = new QLearningUpdateAlgorithm(2, gamma);
}
@Test
public void when_isTerminal_expect_initRewardIs0() {
// Arrange
MockDQN dqnMock = new MockDQN();
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
final Observation observation = new Observation(Nd4j.zeros(1));
setup(1.0);
final Observation observation = new Observation(Nd4j.zeros(1, 2));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(observation, 0, 0.0, true));
@ -55,59 +71,68 @@ public class QLearningUpdateAlgorithmTest {
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 0.0));
}
@Test
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
// Arrange
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
setup(1.0);
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }).reshape(1, 2));
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(observation, 0, 0.0, false));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(2, dqnMock.outputAllParams.size());
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
ArgumentCaptor<INDArray> argument = ArgumentCaptor.forClass(INDArray.class);
verify(dqnMock, times(2)).outputAll(argument.capture());
List<INDArray> values = argument.getAllValues();
assertEquals(-123.0, values.get(0).getDouble(0, 0), 0.00001);
assertEquals(-123.0, values.get(1).getDouble(0, 0), 0.00001);
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 234.0));
}
@Test
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
// Arrange
double gamma = 0.9;
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
setup(gamma);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 }).reshape(1, 2)), 0, 1.0, false));
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 }).reshape(1, 2)), 1, 2.0, true));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
ArgumentCaptor<INDArray> features = ArgumentCaptor.forClass(INDArray.class);
ArgumentCaptor<INDArray> targets = ArgumentCaptor.forClass(INDArray.class);
verify(dqnMock, times(1)).gradient(features.capture(), targets.capture());
// input side -- should be a stack of observations
INDArray input = dqnMock.gradientParams.get(0).getLeft();
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
INDArray featuresValues = features.getValue();
assertEquals(-1.1, featuresValues.getDouble(0, 0), 0.00001);
assertEquals(-1.2, featuresValues.getDouble(0, 1), 0.00001);
assertEquals(-2.1, featuresValues.getDouble(1, 0), 0.00001);
assertEquals(-2.2, featuresValues.getDouble(1, 1), 0.00001);
// target side
INDArray target = dqnMock.gradientParams.get(0).getRight();
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
INDArray targetsValues = targets.getValue();
assertEquals(1.0 + gamma * 2.0, targetsValues.getDouble(0, 0), 0.00001);
assertEquals(1.2, targetsValues.getDouble(0, 1), 0.00001);
assertEquals(2.1, targetsValues.getDouble(1, 0), 0.00001);
assertEquals(2.0, targetsValues.getDouble(1, 1), 0.00001);
}
}

View File

@ -18,6 +18,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
@ -74,6 +75,9 @@ public class QLearningDiscreteTest {
@Mock
QLearningConfiguration mockQlearningConfiguration;
@Mock
ILearningBehavior<Integer> learningBehavior;
// HWC
int[] observationShape = new int[]{3, 10, 10};
int totalObservationSize = 1;
@ -92,12 +96,21 @@ public class QLearningDiscreteTest {
}
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay, ILearningBehavior<Integer> learningBehavior) {
when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
if(learningBehavior != null) {
qLearningDiscrete = mock(
QLearningDiscrete.class,
Mockito.withSettings()
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0, learningBehavior, Nd4j.getRandom())
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
);
}
else {
qLearningDiscrete = mock(
QLearningDiscrete.class,
Mockito.withSettings()
@ -105,6 +118,7 @@ public class QLearningDiscreteTest {
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
);
}
}
private void mockHistoryProcessor(int skipFrames) {
when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
@ -136,7 +150,7 @@ public class QLearningDiscreteTest {
public void when_singleTrainStep_expect_correctValues() {
// Arrange
mockTestContext(100,0,2,1.0, 10);
mockTestContext(100,0,2,1.0, 10, null);
// An example observation and 2 Q values output (2 actions)
Observation observation = new Observation(Nd4j.zeros(observationShape));
@ -162,7 +176,7 @@ public class QLearningDiscreteTest {
@Test
public void when_singleTrainStepSkippedFrames_expect_correctValues() {
// Arrange
mockTestContext(100,0,2,1.0, 10);
mockTestContext(100,0,2,1.0, 10, learningBehavior);
Observation skippedObservation = Observation.SkippedObservation;
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
@ -180,8 +194,8 @@ public class QLearningDiscreteTest {
assertEquals(0, stepReply.getReward(), 1e-5);
assertFalse(stepReply.isDone());
assertFalse(stepReply.getObservation().isSkipped());
assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
verify(learningBehavior, never()).handleNewExperience(any(Observation.class), any(Integer.class), any(Double.class), any(Boolean.class));
verify(mockDQN, never()).output(any(INDArray.class));
}