parent
a18417193d
commit
5568b9d72f
|
@ -1,3 +1,18 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
|
@ -14,7 +29,13 @@ import org.nd4j.common.base.Preconditions;
|
|||
|
||||
import java.util.Map;
|
||||
|
||||
public class Agent<ACTION> {
|
||||
/**
|
||||
* An agent implementation. The Agent will use a {@link IPolicy} to interact with an {@link Environment} and receive
|
||||
* a reward.
|
||||
*
|
||||
* @param <ACTION> The type of action
|
||||
*/
|
||||
public class Agent<ACTION> implements IAgent<ACTION> {
|
||||
@Getter
|
||||
private final String id;
|
||||
|
||||
|
@ -37,19 +58,28 @@ public class Agent<ACTION> {
|
|||
private ACTION lastAction;
|
||||
|
||||
@Getter
|
||||
private int episodeStepNumber;
|
||||
private int episodeStepCount;
|
||||
|
||||
@Getter
|
||||
private double reward;
|
||||
|
||||
protected boolean canContinue;
|
||||
|
||||
private Agent(Builder<ACTION> builder) {
|
||||
this.environment = builder.environment;
|
||||
this.transformProcess = builder.transformProcess;
|
||||
this.policy = builder.policy;
|
||||
this.maxEpisodeSteps = builder.maxEpisodeSteps;
|
||||
this.id = builder.id;
|
||||
/**
|
||||
* @param environment The {@link Environment} to be used
|
||||
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
|
||||
* @param policy The {@link IPolicy} to be used
|
||||
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
|
||||
* @param id A user-supplied id to identify the instance.
|
||||
*/
|
||||
public Agent(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id) {
|
||||
Preconditions.checkArgument(maxEpisodeSteps == null || maxEpisodeSteps > 0, "maxEpisodeSteps must be null (no maximum) or greater than 0, got", maxEpisodeSteps);
|
||||
|
||||
this.environment = environment;
|
||||
this.transformProcess = transformProcess;
|
||||
this.policy = policy;
|
||||
this.maxEpisodeSteps = maxEpisodeSteps;
|
||||
this.id = id;
|
||||
|
||||
listeners = buildListenerList();
|
||||
}
|
||||
|
@ -58,10 +88,17 @@ public class Agent<ACTION> {
|
|||
return new AgentListenerList<ACTION>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a {@link AgentListener} that will be notified when agent events happens
|
||||
* @param listener
|
||||
*/
|
||||
public void addListener(AgentListener listener) {
|
||||
listeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* This will run a single episode
|
||||
*/
|
||||
public void run() {
|
||||
runEpisode();
|
||||
}
|
||||
|
@ -80,7 +117,7 @@ public class Agent<ACTION> {
|
|||
|
||||
canContinue = listeners.notifyBeforeEpisode(this);
|
||||
|
||||
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
|
||||
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepCount < maxEpisodeSteps)) {
|
||||
performStep();
|
||||
}
|
||||
|
||||
|
@ -100,9 +137,9 @@ public class Agent<ACTION> {
|
|||
}
|
||||
|
||||
protected void resetEnvironment() {
|
||||
episodeStepNumber = 0;
|
||||
episodeStepCount = 0;
|
||||
Map<String, Object> channelsData = environment.reset();
|
||||
this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
|
||||
this.observation = transformProcess.transform(channelsData, episodeStepCount, false);
|
||||
}
|
||||
|
||||
protected void resetPolicy() {
|
||||
|
@ -125,7 +162,6 @@ public class Agent<ACTION> {
|
|||
}
|
||||
|
||||
StepResult stepResult = act(action);
|
||||
handleStepResult(stepResult);
|
||||
|
||||
onAfterStep(stepResult);
|
||||
|
||||
|
@ -134,11 +170,11 @@ public class Agent<ACTION> {
|
|||
return;
|
||||
}
|
||||
|
||||
incrementEpisodeStepNumber();
|
||||
incrementEpisodeStepCount();
|
||||
}
|
||||
|
||||
protected void incrementEpisodeStepNumber() {
|
||||
++episodeStepNumber;
|
||||
protected void incrementEpisodeStepCount() {
|
||||
++episodeStepCount;
|
||||
}
|
||||
|
||||
protected ACTION decideAction(Observation observation) {
|
||||
|
@ -150,12 +186,15 @@ public class Agent<ACTION> {
|
|||
}
|
||||
|
||||
protected StepResult act(ACTION action) {
|
||||
return environment.step(action);
|
||||
}
|
||||
Observation observationBeforeAction = observation;
|
||||
|
||||
protected void handleStepResult(StepResult stepResult) {
|
||||
observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
|
||||
reward +=computeReward(stepResult);
|
||||
StepResult stepResult = environment.step(action);
|
||||
observation = convertChannelDataToObservation(stepResult, episodeStepCount + 1);
|
||||
reward += computeReward(stepResult);
|
||||
|
||||
onAfterAction(observationBeforeAction, action, stepResult);
|
||||
|
||||
return stepResult;
|
||||
}
|
||||
|
||||
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
|
||||
|
@ -166,6 +205,10 @@ public class Agent<ACTION> {
|
|||
return stepResult.getReward();
|
||||
}
|
||||
|
||||
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
protected void onAfterStep(StepResult stepResult) {
|
||||
// Do Nothing
|
||||
}
|
||||
|
@ -174,16 +217,24 @@ public class Agent<ACTION> {
|
|||
// Do Nothing
|
||||
}
|
||||
|
||||
public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||
/**
|
||||
*
|
||||
* @param environment
|
||||
* @param transformProcess
|
||||
* @param policy
|
||||
* @param <ACTION>
|
||||
* @return
|
||||
*/
|
||||
public static <ACTION> Builder<ACTION, Agent> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||
return new Builder<>(environment, transformProcess, policy);
|
||||
}
|
||||
|
||||
public static class Builder<ACTION> {
|
||||
private final Environment<ACTION> environment;
|
||||
private final TransformProcess transformProcess;
|
||||
private final IPolicy<ACTION> policy;
|
||||
private Integer maxEpisodeSteps = null; // Default, no max
|
||||
private String id;
|
||||
public static class Builder<ACTION, AGENT_TYPE extends Agent> {
|
||||
protected final Environment<ACTION> environment;
|
||||
protected final TransformProcess transformProcess;
|
||||
protected final IPolicy<ACTION> policy;
|
||||
protected Integer maxEpisodeSteps = null; // Default, no max
|
||||
protected String id;
|
||||
|
||||
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||
this.environment = environment;
|
||||
|
@ -191,20 +242,20 @@ public class Agent<ACTION> {
|
|||
this.policy = policy;
|
||||
}
|
||||
|
||||
public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
|
||||
public Builder<ACTION, AGENT_TYPE> maxEpisodeSteps(int maxEpisodeSteps) {
|
||||
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
|
||||
this.maxEpisodeSteps = maxEpisodeSteps;
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder<ACTION> id(String id) {
|
||||
public Builder<ACTION, AGENT_TYPE> id(String id) {
|
||||
this.id = id;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Agent build() {
|
||||
return new Agent(this);
|
||||
public AGENT_TYPE build() {
|
||||
return (AGENT_TYPE)new Agent<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
|
||||
/**
|
||||
* The ActionLearner is an {@link Agent} that delegate the learning to a {@link ILearningBehavior}.
|
||||
* @param <ACTION> The type of the action
|
||||
*/
|
||||
public class AgentLearner<ACTION> extends Agent<ACTION> implements IAgentLearner<ACTION> {
|
||||
|
||||
@Getter
|
||||
private int totalStepCount = 0;
|
||||
|
||||
private final ILearningBehavior<ACTION> learningBehavior;
|
||||
private double rewardAtLastExperience;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param environment The {@link Environment} to be used
|
||||
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
|
||||
* @param policy The {@link IPolicy} to be used
|
||||
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
|
||||
* @param id A user-supplied id to identify the instance.
|
||||
* @param learningBehavior The {@link ILearningBehavior} that will be used to supervise the learning.
|
||||
*/
|
||||
public AgentLearner(Environment<ACTION> environment, TransformProcess transformProcess, IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id, @NonNull ILearningBehavior<ACTION> learningBehavior) {
|
||||
super(environment, transformProcess, policy, maxEpisodeSteps, id);
|
||||
|
||||
this.learningBehavior = learningBehavior;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void reset() {
|
||||
super.reset();
|
||||
|
||||
rewardAtLastExperience = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onBeforeEpisode() {
|
||||
super.onBeforeEpisode();
|
||||
|
||||
learningBehavior.handleEpisodeStart();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
|
||||
if(!observationBeforeAction.isSkipped()) {
|
||||
double rewardSinceLastExperience = getReward() - rewardAtLastExperience;
|
||||
learningBehavior.handleNewExperience(observationBeforeAction, action, rewardSinceLastExperience, stepResult.isTerminal());
|
||||
|
||||
rewardAtLastExperience = getReward();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onAfterEpisode() {
|
||||
learningBehavior.handleEpisodeEnd(getObservation());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void incrementEpisodeStepCount() {
|
||||
super.incrementEpisodeStepCount();
|
||||
++totalStepCount;
|
||||
}
|
||||
|
||||
// FIXME: parent is still visible
|
||||
public static <ACTION> AgentLearner.Builder<ACTION, AgentLearner<ACTION>> builder(Environment<ACTION> environment,
|
||||
TransformProcess transformProcess,
|
||||
IPolicy<ACTION> policy,
|
||||
ILearningBehavior<ACTION> learningBehavior) {
|
||||
return new AgentLearner.Builder<ACTION, AgentLearner<ACTION>>(environment, transformProcess, policy, learningBehavior);
|
||||
}
|
||||
|
||||
public static class Builder<ACTION, AGENT_TYPE extends AgentLearner<ACTION>> extends Agent.Builder<ACTION, AGENT_TYPE> {
|
||||
|
||||
private final ILearningBehavior<ACTION> learningBehavior;
|
||||
|
||||
public Builder(@NonNull Environment<ACTION> environment,
|
||||
@NonNull TransformProcess transformProcess,
|
||||
@NonNull IPolicy<ACTION> policy,
|
||||
@NonNull ILearningBehavior<ACTION> learningBehavior) {
|
||||
super(environment, transformProcess, policy);
|
||||
|
||||
this.learningBehavior = learningBehavior;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AGENT_TYPE build() {
|
||||
return (AGENT_TYPE)new AgentLearner<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id, learningBehavior);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
|
||||
/**
|
||||
* The interface of {@link Agent}
|
||||
* @param <ACTION>
|
||||
*/
|
||||
public interface IAgent<ACTION> {
|
||||
/**
|
||||
* Will play a single episode
|
||||
*/
|
||||
void run();
|
||||
|
||||
/**
|
||||
* @return A user-supplied id to identify the IAgent instance.
|
||||
*/
|
||||
String getId();
|
||||
|
||||
/**
|
||||
* @return The {@link Environment} instance being used by the agent.
|
||||
*/
|
||||
Environment<ACTION> getEnvironment();
|
||||
|
||||
/**
|
||||
* @return The {@link IPolicy} instance being used by the agent.
|
||||
*/
|
||||
IPolicy<ACTION> getPolicy();
|
||||
|
||||
/**
|
||||
* @return The step count taken in the current episode.
|
||||
*/
|
||||
int getEpisodeStepCount();
|
||||
|
||||
/**
|
||||
* @return The cumulative reward received in the current episode.
|
||||
*/
|
||||
double getReward();
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
public interface IAgentLearner<ACTION> extends IAgent<ACTION> {
|
||||
|
||||
/**
|
||||
* @return The total count of steps taken by this AgentLearner, for all episodes.
|
||||
*/
|
||||
int getTotalStepCount();
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.learning;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
/**
|
||||
* The <code>ILearningBehavior</code> implementations are in charge of the training. Through this interface, they are
|
||||
* notified as new experience is generated.
|
||||
*
|
||||
* @param <ACTION> The type of action
|
||||
*/
|
||||
public interface ILearningBehavior<ACTION> {
|
||||
|
||||
/**
|
||||
* This method is called when a new episode has been started.
|
||||
*/
|
||||
void handleEpisodeStart();
|
||||
|
||||
/**
|
||||
* This method is called when new experience is generated.
|
||||
*
|
||||
* @param observation The observation prior to taking the action
|
||||
* @param action The action that has been taken
|
||||
* @param reward The reward received by taking the action
|
||||
* @param isTerminal True if the episode ended after taking the action
|
||||
*/
|
||||
void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal);
|
||||
|
||||
/**
|
||||
* This method is called when the episode ends or the maximum number of episode steps is reached.
|
||||
*
|
||||
* @param finalObservation The observation after the last action of the episode has been taken.
|
||||
*/
|
||||
void handleEpisodeEnd(Observation finalObservation);
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.learning;
|
||||
|
||||
import lombok.Builder;
|
||||
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
/**
|
||||
* A generic {@link ILearningBehavior} that delegates the handling of experience to a {@link ExperienceHandler} and
|
||||
* the update logic to a {@link IUpdateRule}
|
||||
*
|
||||
* @param <ACTION> The type of the action
|
||||
* @param <EXPERIENCE_TYPE> The type of experience the ExperienceHandler needs
|
||||
*/
|
||||
@Builder
|
||||
public class LearningBehavior<ACTION, EXPERIENCE_TYPE> implements ILearningBehavior<ACTION> {
|
||||
|
||||
@Builder.Default
|
||||
private int experienceUpdateSize = 64;
|
||||
|
||||
private final ExperienceHandler<ACTION, EXPERIENCE_TYPE> experienceHandler;
|
||||
private final IUpdateRule<EXPERIENCE_TYPE> updateRule;
|
||||
|
||||
@Override
|
||||
public void handleEpisodeStart() {
|
||||
experienceHandler.reset();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal) {
|
||||
experienceHandler.addExperience(observation, action, reward, isTerminal);
|
||||
if(experienceHandler.isTrainingBatchReady()) {
|
||||
updateRule.update(experienceHandler.generateTrainingBatch());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleEpisodeEnd(Observation finalObservation) {
|
||||
experienceHandler.setFinalObservation(finalObservation);
|
||||
if(experienceHandler.isTrainingBatchReady()) {
|
||||
updateRule.update(experienceHandler.generateTrainingBatch());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,23 +1,66 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.Agent;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
/**
|
||||
* The base definition of all {@link Agent} event listeners
|
||||
*/
|
||||
public interface AgentListener<ACTION> {
|
||||
enum ListenerResponse {
|
||||
/**
|
||||
* Tell the learning process to continue calling the listeners and the training.
|
||||
* Tell the {@link Agent} to continue calling the listeners and the processing.
|
||||
*/
|
||||
CONTINUE,
|
||||
|
||||
/**
|
||||
* Tell the learning process to stop calling the listeners and terminate the training.
|
||||
* Tell the {@link Agent} to interrupt calling the listeners and stop the processing.
|
||||
*/
|
||||
STOP,
|
||||
}
|
||||
|
||||
/**
|
||||
* Called when a new episode is about to start.
|
||||
* @param agent The agent that generated the event
|
||||
*
|
||||
* @return A {@link ListenerResponse}.
|
||||
*/
|
||||
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
|
||||
|
||||
/**
|
||||
* Called when a step is about to be taken.
|
||||
*
|
||||
* @param agent The agent that generated the event
|
||||
* @param observation The observation before the action is taken
|
||||
* @param action The action that will be performed
|
||||
*
|
||||
* @return A {@link ListenerResponse}.
|
||||
*/
|
||||
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
|
||||
|
||||
/**
|
||||
* Called after a step has been taken.
|
||||
*
|
||||
* @param agent The agent that generated the event
|
||||
* @param stepResult The {@link StepResult} result of the step.
|
||||
*
|
||||
* @return A {@link ListenerResponse}.
|
||||
*/
|
||||
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
|
||||
}
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.listener;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.Agent;
|
||||
|
@ -7,6 +22,10 @@ import org.deeplearning4j.rl4j.observation.Observation;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A class that manages a list of {@link AgentListener AgentListeners} listening to an {@link Agent}.
|
||||
* @param <ACTION>
|
||||
*/
|
||||
public class AgentListenerList<ACTION> {
|
||||
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
|
||||
|
||||
|
@ -18,6 +37,13 @@ public class AgentListenerList<ACTION> {
|
|||
listeners.add(listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will notify all listeners that an episode is about to start. If a listener returns
|
||||
* {@link AgentListener.ListenerResponse STOP}, any following listener is skipped.
|
||||
*
|
||||
* @param agent The agent that generated the event.
|
||||
* @return False if the processing should be stopped
|
||||
*/
|
||||
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
|
||||
|
@ -28,6 +54,13 @@ public class AgentListenerList<ACTION> {
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param agent The agent that generated the event.
|
||||
* @param observation The observation before the action is taken
|
||||
* @param action The action that will be performed
|
||||
* @return False if the processing should be stopped
|
||||
*/
|
||||
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
|
||||
|
@ -38,6 +71,12 @@ public class AgentListenerList<ACTION> {
|
|||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param agent The agent that generated the event.
|
||||
* @param stepResult The {@link StepResult} result of the step.
|
||||
* @return False if the processing should be stopped
|
||||
*/
|
||||
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
|
||||
for (AgentListener<ACTION> listener : listeners) {
|
||||
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.update;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
// Temporary class that will be replaced with a more generic class that delegates gradient computation
|
||||
// and network update to sub components.
|
||||
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
|
||||
|
||||
@Getter
|
||||
private final IDQN qNetwork;
|
||||
|
||||
@Getter
|
||||
private IDQN targetQNetwork;
|
||||
private final int targetUpdateFrequency;
|
||||
|
||||
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
||||
|
||||
@Getter
|
||||
private int updateCount = 0;
|
||||
|
||||
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
||||
this.qNetwork = qNetwork;
|
||||
this.targetQNetwork = qNetwork.clone();
|
||||
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||
tdTargetAlgorithm = isDoubleDQN
|
||||
? new DoubleDQN(this, gamma, errorClamp)
|
||||
: new StandardDQN(this, gamma, errorClamp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(List<Transition<Integer>> trainingBatch) {
|
||||
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
||||
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
||||
if(++updateCount % targetUpdateFrequency == 0) {
|
||||
targetQNetwork = qNetwork.clone();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.update;
|
||||
|
||||
import lombok.Value;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
|
||||
// Work in progress
|
||||
@Value
|
||||
public class Gradients {
|
||||
private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[]
|
||||
private int batchSize;
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.update;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* The role of IUpdateRule implementations is to use an experience batch to improve the accuracy of the policy.
|
||||
* Used by {@link org.deeplearning4j.rl4j.agent.AgentLearner AgentLearner}
|
||||
* @param <EXPERIENCE_TYPE> The type of the experience
|
||||
*/
|
||||
public interface IUpdateRule<EXPERIENCE_TYPE> {
|
||||
/**
|
||||
* Perform the update
|
||||
* @param trainingBatch A batch of experience
|
||||
*/
|
||||
void update(List<EXPERIENCE_TYPE> trainingBatch);
|
||||
|
||||
/**
|
||||
* @return The total number of times the policy has been updated. In a multi-agent learning context, this total is
|
||||
* for all the agents.
|
||||
*/
|
||||
int getUpdateCount();
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
@Value
|
||||
public class ActionSchema<ACTION> {
|
||||
private ACTION noOp;
|
||||
//FIXME ACTION randomAction();
|
||||
}
|
|
@ -1,11 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* An interface for environments used by the {@link org.deeplearning4j.rl4j.agent.Agent Agents}.
|
||||
* @param <ACTION> The type of actions
|
||||
*/
|
||||
public interface Environment<ACTION> {
|
||||
|
||||
/**
|
||||
* @return The {@link Schema} of the environment
|
||||
*/
|
||||
Schema<ACTION> getSchema();
|
||||
|
||||
/**
|
||||
* Reset the environment's state to start a new episode.
|
||||
* @return
|
||||
*/
|
||||
Map<String, Object> reset();
|
||||
|
||||
/**
|
||||
* Perform a single step.
|
||||
*
|
||||
* @param action The action taken
|
||||
* @return A {@link StepResult} describing the result of the step.
|
||||
*/
|
||||
StepResult step(ACTION action);
|
||||
|
||||
/**
|
||||
* @return True if the episode is finished
|
||||
*/
|
||||
boolean isEpisodeFinished();
|
||||
|
||||
/**
|
||||
* Called when the agent is finished using this environment instance.
|
||||
*/
|
||||
void close();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
// Work in progress
|
||||
public interface IActionSchema<ACTION> {
|
||||
ACTION getNoOp();
|
||||
|
||||
// Review: A schema should be data-only and not have behavior
|
||||
ACTION getRandomAction();
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
// Work in progress
|
||||
public class IntegerActionSchema implements IActionSchema<Integer> {
|
||||
|
||||
private final int numActions;
|
||||
private final int noOpAction;
|
||||
private final Random rnd;
|
||||
|
||||
public IntegerActionSchema(int numActions, int noOpAction) {
|
||||
this(numActions, noOpAction, Nd4j.getRandom());
|
||||
}
|
||||
|
||||
public IntegerActionSchema(int numActions, int noOpAction, Random rnd) {
|
||||
this.numActions = numActions;
|
||||
this.noOpAction = noOpAction;
|
||||
this.rnd = rnd;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getNoOp() {
|
||||
return noOpAction;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getRandomAction() {
|
||||
return rnd.nextInt(numActions);
|
||||
}
|
||||
}
|
|
@ -1,8 +1,24 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
||||
// Work in progress
|
||||
@Value
|
||||
public class Schema<ACTION> {
|
||||
private ActionSchema<ACTION> actionSchema;
|
||||
private IActionSchema<ACTION> actionSchema;
|
||||
}
|
||||
|
|
|
@ -1,3 +1,18 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.environment;
|
||||
|
||||
import lombok.Value;
|
||||
|
|
|
@ -41,6 +41,11 @@ public interface ExperienceHandler<A, E> {
|
|||
*/
|
||||
int getTrainingBatchSize();
|
||||
|
||||
/**
|
||||
* @return True if a batch is ready for training.
|
||||
*/
|
||||
boolean isTrainingBatchReady();
|
||||
|
||||
/**
|
||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||
* @return The list of experience elements
|
||||
|
|
|
@ -36,6 +36,7 @@ import java.util.List;
|
|||
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
|
||||
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
|
||||
private static final int DEFAULT_BATCH_SIZE = 32;
|
||||
private final int batchSize;
|
||||
|
||||
private IExpReplay<A> expReplay;
|
||||
|
||||
|
@ -43,6 +44,7 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
|
|||
|
||||
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
|
||||
this.expReplay = expReplay;
|
||||
this.batchSize = expReplay.getDesignatedBatchSize();
|
||||
}
|
||||
|
||||
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
|
||||
|
@ -64,6 +66,11 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
|
|||
return expReplay.getBatchSize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTrainingBatchReady() {
|
||||
return expReplay.getBatchSize() >= batchSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
|
||||
*/
|
||||
|
|
|
@ -30,10 +30,18 @@ import java.util.List;
|
|||
*/
|
||||
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
|
||||
|
||||
private final int batchSize;
|
||||
|
||||
private boolean isFinalObservationSet;
|
||||
|
||||
public StateActionExperienceHandler(int batchSize) {
|
||||
this.batchSize = batchSize;
|
||||
}
|
||||
|
||||
private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
|
||||
|
||||
public void setFinalObservation(Observation observation) {
|
||||
// Do nothing
|
||||
isFinalObservationSet = true;
|
||||
}
|
||||
|
||||
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||
|
@ -45,6 +53,12 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
|
|||
return stateActionPairs.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isTrainingBatchReady() {
|
||||
return stateActionPairs.size() >= batchSize
|
||||
|| (isFinalObservationSet && stateActionPairs.size() > 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||
* Note: the experience store is cleared after calling this method.
|
||||
|
@ -62,6 +76,7 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
|
|||
@Override
|
||||
public void reset() {
|
||||
stateActionPairs = new ArrayList<>();
|
||||
isFinalObservationSet = false;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -24,17 +24,38 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class INDArrayHelper {
|
||||
|
||||
/**
|
||||
* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
|
||||
*
|
||||
* We must have either shape 2 (NK) or shape 4 (NCHW)
|
||||
* Force the input source to have the correct shape:
|
||||
* <p><ul>
|
||||
* <li>DL4J requires it to be at least 2D</li>
|
||||
* <li>RL4J has a convention to have the batch size on dimension 0 to all INDArrays</li>
|
||||
* </ul></p>
|
||||
* @param source The {@link INDArray} to be corrected.
|
||||
* @return The corrected INDArray
|
||||
*/
|
||||
public static INDArray forceCorrectShape(INDArray source) {
|
||||
|
||||
return source.shape()[0] == 1 && source.shape().length > 1
|
||||
return source.shape()[0] == 1 && source.rank() > 1
|
||||
? source
|
||||
: Nd4j.expandDims(source, 0);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* This will create a INDArray with <i>batchSize</i> as dimension 0 and <i>shape</i> as other dimensions.
|
||||
* For example, if <i>batchSize</i> is 10 and shape is { 1, 3, 4 }, the resulting INDArray shape will be { 10, 3, 4}
|
||||
* @param batchSize The size of the batch to create
|
||||
* @param shape The shape of individual elements.
|
||||
* Note: all shapes in RL4J should have a batch size as dimension 0; in this case the batch size should be 1.
|
||||
* @return A INDArray
|
||||
*/
|
||||
public static INDArray createBatchForShape(long batchSize, long... shape) {
|
||||
long[] batchShape;
|
||||
|
||||
batchShape = new long[shape.length];
|
||||
System.arraycopy(shape, 0, batchShape, 0, shape.length);
|
||||
|
||||
batchShape[0] = batchSize;
|
||||
return Nd4j.create(batchShape);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.gym.StepReply;
|
|||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
@ -49,7 +50,7 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
|
|||
|
||||
// TODO: Make it configurable with a builder
|
||||
@Setter(AccessLevel.PROTECTED) @Getter
|
||||
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
||||
private ExperienceHandler experienceHandler;
|
||||
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
||||
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||
|
@ -60,6 +61,17 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
|
|||
synchronized (asyncGlobal) {
|
||||
current = (NN) asyncGlobal.getTarget().clone();
|
||||
}
|
||||
|
||||
experienceHandler = new StateActionExperienceHandler(getNStep());
|
||||
}
|
||||
|
||||
private int getNStep() {
|
||||
IAsyncLearningConfiguration configuration = getConfiguration();
|
||||
if(configuration == null) {
|
||||
return Integer.MAX_VALUE;
|
||||
}
|
||||
|
||||
return configuration.getNStep();
|
||||
}
|
||||
|
||||
// TODO: Add an actor-learner class and be able to inject the update algorithm
|
||||
|
|
|
@ -71,7 +71,6 @@ public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> ex
|
|||
|
||||
@Override
|
||||
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
|
||||
return new QLearningUpdateAlgorithm(getMdp().getActionSpace().getSize(), configuration.getGamma());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
|||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -27,15 +27,12 @@ import java.util.List;
|
|||
|
||||
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
||||
|
||||
private final int[] shape;
|
||||
private final int actionSpaceSize;
|
||||
private final double gamma;
|
||||
|
||||
public QLearningUpdateAlgorithm(int[] shape,
|
||||
int actionSpaceSize,
|
||||
public QLearningUpdateAlgorithm(int actionSpaceSize,
|
||||
double gamma) {
|
||||
|
||||
this.shape = shape;
|
||||
this.actionSpaceSize = actionSpaceSize;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
@ -44,33 +41,34 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
|||
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
|
||||
int size = experience.size();
|
||||
|
||||
int[] nshape = Learning.makeShape(size, shape);
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
||||
|
||||
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||
|
||||
INDArray data = stateActionPair.getObservation().getData();
|
||||
INDArray features = INDArrayHelper.createBatchForShape(size, data.shape());
|
||||
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
||||
|
||||
double r;
|
||||
if (stateActionPair.isTerminal()) {
|
||||
r = 0;
|
||||
} else {
|
||||
INDArray[] output = null;
|
||||
output = current.outputAll(stateActionPair.getObservation().getData());
|
||||
output = current.outputAll(data);
|
||||
r = Nd4j.max(output[0]).getDouble(0);
|
||||
}
|
||||
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
stateActionPair = experience.get(i);
|
||||
data = stateActionPair.getObservation().getData();
|
||||
|
||||
input.putRow(i, stateActionPair.getObservation().getData());
|
||||
features.putRow(i, data);
|
||||
|
||||
r = stateActionPair.getReward() + gamma * r;
|
||||
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
|
||||
INDArray[] output = current.outputAll(data);
|
||||
INDArray row = output[0];
|
||||
row = row.putScalar(stateActionPair.getAction(), r);
|
||||
targets.putRow(i, row);
|
||||
}
|
||||
|
||||
return current.gradient(input, targets);
|
||||
return current.gradient(features, targets);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,6 +80,11 @@ public class ExpReplay<A> implements IExpReplay<A> {
|
|||
//log.info("size: "+storage.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDesignatedBatchSize() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
public int getBatchSize() {
|
||||
int storageSize = storage.size();
|
||||
return Math.min(storageSize, batchSize);
|
||||
|
|
|
@ -47,4 +47,9 @@ public interface IExpReplay<A> {
|
|||
* @param transition a new transition to store
|
||||
*/
|
||||
void store(Transition<A> transition);
|
||||
|
||||
/**
|
||||
* @return The desired size of batches
|
||||
*/
|
||||
int getDesignatedBatchSize();
|
||||
}
|
||||
|
|
|
@ -51,25 +51,16 @@ import java.util.List;
|
|||
@Slf4j
|
||||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||
extends SyncLearning<O, A, AS, IDQN>
|
||||
implements TargetQNetworkSource, IEpochTrainer {
|
||||
implements IEpochTrainer {
|
||||
|
||||
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||
|
||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
||||
protected abstract EpsGreedy<A> getEgPolicy();
|
||||
|
||||
public abstract MDP<O, A, AS> getMdp();
|
||||
|
||||
public abstract IDQN getQNetwork();
|
||||
|
||||
public abstract IDQN getTargetQNetwork();
|
||||
|
||||
protected abstract void setTargetQNetwork(IDQN dqn);
|
||||
|
||||
protected void updateTargetNetwork() {
|
||||
log.info("Update target network");
|
||||
setTargetQNetwork(getQNetwork().clone());
|
||||
}
|
||||
|
||||
public IDQN getNeuralNet() {
|
||||
return getQNetwork();
|
||||
}
|
||||
|
@ -101,11 +92,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
int numQ = 0;
|
||||
List<Double> scores = new ArrayList<>();
|
||||
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
||||
|
||||
if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
|
||||
updateTargetNetwork();
|
||||
}
|
||||
|
||||
QLStepReturn<Observation> stepR = trainStep(obs);
|
||||
|
||||
if (!stepR.getMaxQ().isNaN()) {
|
||||
|
@ -146,7 +132,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
|
||||
protected void resetNetworks() {
|
||||
getQNetwork().reset();
|
||||
getTargetQNetwork().reset();
|
||||
}
|
||||
|
||||
private InitMdp<Observation> refacInitMdp() {
|
||||
|
|
|
@ -21,6 +21,10 @@ import lombok.AccessLevel;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
|
||||
import org.deeplearning4j.rl4j.agent.update.DQNNeuralNetUpdateRule;
|
||||
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
|
@ -28,9 +32,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
|||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
|
@ -41,12 +42,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
|||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||
|
@ -63,22 +60,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
@Getter
|
||||
private DQNPolicy<O> policy;
|
||||
@Getter
|
||||
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
|
||||
private EpsGreedy<Integer> egPolicy;
|
||||
|
||||
@Getter
|
||||
final private IDQN qNetwork;
|
||||
@Getter
|
||||
@Setter(AccessLevel.PROTECTED)
|
||||
private IDQN targetQNetwork;
|
||||
|
||||
private int lastAction;
|
||||
private double accuReward = 0;
|
||||
|
||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||
|
||||
// TODO: User a builder and remove the setter
|
||||
@Getter(AccessLevel.PROTECTED) @Setter
|
||||
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
|
||||
private final ILearningBehavior<Integer> learningBehavior;
|
||||
|
||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||
return mdp;
|
||||
|
@ -88,21 +78,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||
}
|
||||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) {
|
||||
this(mdp, dqn, conf, epsilonNbStep, buildLearningBehavior(dqn, conf, random), random);
|
||||
}
|
||||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
||||
int epsilonNbStep, Random random) {
|
||||
int epsilonNbStep, ILearningBehavior<Integer> learningBehavior, Random random) {
|
||||
this.configuration = conf;
|
||||
this.mdp = new LegacyMDPWrapper<>(mdp, null);
|
||||
qNetwork = dqn;
|
||||
targetQNetwork = dqn.clone();
|
||||
policy = new DQNPolicy(getQNetwork());
|
||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
||||
this);
|
||||
|
||||
tdTargetAlgorithm = conf.isDoubleDQN()
|
||||
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
||||
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
|
||||
this.learningBehavior = learningBehavior;
|
||||
}
|
||||
|
||||
private static ILearningBehavior<Integer> buildLearningBehavior(IDQN qNetwork, QLearningConfiguration conf, Random random) {
|
||||
IUpdateRule<Transition<Integer>> updateRule = new DQNNeuralNetUpdateRule(qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp());
|
||||
ExperienceHandler<Integer, Transition<Integer>> experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||
return LearningBehavior.<Integer, Transition<Integer>>builder()
|
||||
.experienceHandler(experienceHandler)
|
||||
.updateRule(updateRule)
|
||||
.experienceUpdateSize(conf.getBatchSize())
|
||||
.build();
|
||||
|
||||
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||
}
|
||||
|
||||
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
||||
|
@ -119,7 +119,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
public void preEpoch() {
|
||||
lastAction = mdp.getActionSpace().noOp();
|
||||
accuReward = 0;
|
||||
experienceHandler.reset();
|
||||
learningBehavior.handleEpisodeStart();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -136,12 +136,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
*/
|
||||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||
|
||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||
int updateStart = this.getConfiguration().getUpdateStart()
|
||||
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
|
||||
|
||||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||
|
||||
//if step of training, just repeat lastAction
|
||||
|
@ -160,29 +154,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
if (!obs.isSkipped()) {
|
||||
|
||||
// Add experience
|
||||
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
|
||||
learningBehavior.handleNewExperience(obs, lastAction, accuReward, stepReply.isDone());
|
||||
accuReward = 0;
|
||||
|
||||
// Update NN
|
||||
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
||||
if (this.getStepCount() > updateStart) {
|
||||
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
|
||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||
}
|
||||
}
|
||||
|
||||
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||
}
|
||||
|
||||
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
||||
if (transitions.size() == 0)
|
||||
throw new IllegalArgumentException("too few transitions");
|
||||
|
||||
return tdTargetAlgorithm.computeTDTargets(transitions);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finishEpoch(Observation observation) {
|
||||
experienceHandler.setFinalObservation(observation);
|
||||
learningBehavior.handleEpisodeEnd(observation);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,21 +2,19 @@ package org.deeplearning4j.rl4j.mdp;
|
|||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.Schema;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.environment.*;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
public class CartpoleEnvironment implements Environment<Integer> {
|
||||
private static final int NUM_ACTIONS = 2;
|
||||
private static final int ACTION_LEFT = 0;
|
||||
private static final int ACTION_RIGHT = 1;
|
||||
|
||||
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
|
||||
private final Schema<Integer> schema;
|
||||
|
||||
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
||||
|
||||
|
@ -48,11 +46,12 @@ public class CartpoleEnvironment implements Environment<Integer> {
|
|||
private Integer stepsBeyondDone;
|
||||
|
||||
public CartpoleEnvironment() {
|
||||
rnd = new Random();
|
||||
this(Nd4j.getRandom());
|
||||
}
|
||||
|
||||
public CartpoleEnvironment(int seed) {
|
||||
rnd = new Random(seed);
|
||||
public CartpoleEnvironment(Random rnd) {
|
||||
this.rnd = rnd;
|
||||
this.schema = new Schema<Integer>(new IntegerActionSchema(NUM_ACTIONS, ACTION_LEFT, rnd));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,16 +17,19 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.policy;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.environment.IActionSchema;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
||||
|
@ -38,18 +41,60 @@ import org.nd4j.linalg.api.rng.Random;
|
|||
* epislon is annealed to minEpsilon over epsilonNbStep steps
|
||||
*
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Slf4j
|
||||
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
|
||||
public class EpsGreedy<A> extends Policy<A> {
|
||||
|
||||
final private Policy<A> policy;
|
||||
final private MDP<OBSERVATION, A, AS> mdp;
|
||||
final private INeuralNetPolicy<A> policy;
|
||||
final private int updateStart;
|
||||
final private int epsilonNbStep;
|
||||
final private Random rnd;
|
||||
final private double minEpsilon;
|
||||
|
||||
private final IActionSchema<A> actionSchema;
|
||||
|
||||
final private MDP<Encodable, A, ActionSpace<A>> mdp;
|
||||
final private IEpochTrainer learning;
|
||||
|
||||
// Using agent's (learning's) step count is incorrect; frame skipping makes epsilon's value decrease too quickly
|
||||
private int annealingStep = 0;
|
||||
|
||||
@Deprecated
|
||||
public <OBSERVATION extends Encodable, AS extends ActionSpace<A>> EpsGreedy(Policy<A> policy,
|
||||
MDP<Encodable, A, ActionSpace<A>> mdp,
|
||||
int updateStart,
|
||||
int epsilonNbStep,
|
||||
Random rnd,
|
||||
double minEpsilon,
|
||||
IEpochTrainer learning) {
|
||||
this.policy = policy;
|
||||
this.mdp = mdp;
|
||||
this.updateStart = updateStart;
|
||||
this.epsilonNbStep = epsilonNbStep;
|
||||
this.rnd = rnd;
|
||||
this.minEpsilon = minEpsilon;
|
||||
this.learning = learning;
|
||||
|
||||
this.actionSchema = null;
|
||||
}
|
||||
|
||||
public EpsGreedy(@NonNull Policy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep) {
|
||||
this(policy, actionSchema, minEpsilon, updateStart, epsilonNbStep, null);
|
||||
}
|
||||
|
||||
@Builder
|
||||
public EpsGreedy(@NonNull INeuralNetPolicy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep, Random rnd) {
|
||||
this.policy = policy;
|
||||
|
||||
this.rnd = rnd == null ? Nd4j.getRandom() : rnd;
|
||||
this.minEpsilon = minEpsilon;
|
||||
this.updateStart = updateStart;
|
||||
this.epsilonNbStep = epsilonNbStep;
|
||||
this.actionSchema = actionSchema;
|
||||
|
||||
this.mdp = null;
|
||||
this.learning = null;
|
||||
}
|
||||
|
||||
public NeuralNet getNeuralNet() {
|
||||
return policy.getNeuralNet();
|
||||
}
|
||||
|
@ -57,6 +102,11 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
|
|||
public A nextAction(INDArray input) {
|
||||
|
||||
double ep = getEpsilon();
|
||||
if(actionSchema != null) {
|
||||
// Only legacy classes should pass here.
|
||||
throw new RuntimeException("nextAction(Observation observation) should be called when using a AgentLearner");
|
||||
}
|
||||
|
||||
if (learning.getStepCount() % 500 == 1)
|
||||
log.info("EP: " + ep + " " + learning.getStepCount());
|
||||
if (rnd.nextDouble() > ep)
|
||||
|
@ -66,10 +116,31 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
|
|||
}
|
||||
|
||||
public A nextAction(Observation observation) {
|
||||
return this.nextAction(observation.getData());
|
||||
if(actionSchema == null) {
|
||||
return this.nextAction(observation.getData());
|
||||
}
|
||||
|
||||
A result;
|
||||
|
||||
double ep = getEpsilon();
|
||||
if (annealingStep % 500 == 1) {
|
||||
log.info("EP: " + ep + " " + annealingStep);
|
||||
}
|
||||
|
||||
if (rnd.nextDouble() > ep) {
|
||||
result = policy.nextAction(observation);
|
||||
}
|
||||
else {
|
||||
result = actionSchema.getRandomAction();
|
||||
}
|
||||
|
||||
++annealingStep;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public double getEpsilon() {
|
||||
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
|
||||
int step = actionSchema != null ? annealingStep : learning.getStepCount();
|
||||
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
package org.deeplearning4j.rl4j.policy;
|
||||
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
||||
public interface INeuralNetPolicy<ACTION> extends IPolicy<ACTION> {
|
||||
NeuralNet getNeuralNet();
|
||||
}
|
|
@ -34,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
|||
*
|
||||
* A Policy responsability is to choose the next action given a state
|
||||
*/
|
||||
public abstract class Policy<A> implements IPolicy<A> {
|
||||
public abstract class Policy<A> implements INeuralNetPolicy<A> {
|
||||
|
||||
public abstract NeuralNet getNeuralNet();
|
||||
|
||||
|
|
|
@ -0,0 +1,211 @@
|
|||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.IntegerActionSchema;
|
||||
import org.deeplearning4j.rl4j.environment.Schema;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.invocation.InvocationOnMock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.mockito.stubbing.Answer;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class AgentLearnerTest {
|
||||
|
||||
@Mock
|
||||
Environment<Integer> environmentMock;
|
||||
|
||||
@Mock
|
||||
TransformProcess transformProcessMock;
|
||||
|
||||
@Mock
|
||||
IPolicy<Integer> policyMock;
|
||||
|
||||
@Mock
|
||||
LearningBehavior<Integer, Object> learningBehaviorMock;
|
||||
|
||||
@Test
|
||||
public void when_episodeIsStarted_expect_learningBehaviorHandleEpisodeStartCalled() {
|
||||
// Arrange
|
||||
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||
.maxEpisodeSteps(3)
|
||||
.build();
|
||||
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
verify(learningBehaviorMock, times(1)).handleEpisodeStart();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalled_expect_experienceHandledWithLearningBehavior() {
|
||||
// Arrange
|
||||
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||
.maxEpisodeSteps(4)
|
||||
.build();
|
||||
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
|
||||
double[] reward = new double[] { 0.0 };
|
||||
when(environmentMock.step(any(Integer.class)))
|
||||
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||
.thenAnswer(new Answer<Observation>() {
|
||||
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||
int step = (int)invocation.getArgument(1);
|
||||
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||
return (step % 2 == 0 || isTerminal)
|
||||
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||
: Observation.SkippedObservation;
|
||||
}
|
||||
});
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
|
||||
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||
ArgumentCaptor<Boolean> isTerminalCaptor = ArgumentCaptor.forClass(Boolean.class);
|
||||
|
||||
verify(learningBehaviorMock, times(2)).handleNewExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminalCaptor.capture());
|
||||
List<Observation> observations = observationCaptor.getAllValues();
|
||||
List<Integer> actions = actionCaptor.getAllValues();
|
||||
List<Double> rewards = rewardCaptor.getAllValues();
|
||||
List<Boolean> isTerminalList = isTerminalCaptor.getAllValues();
|
||||
|
||||
assertEquals(0.0, observations.get(0).getData().getDouble(0), 0.00001);
|
||||
assertEquals(0, (int)actions.get(0));
|
||||
assertEquals(0.0 + 1.0, rewards.get(0), 0.00001);
|
||||
assertFalse(isTerminalList.get(0));
|
||||
|
||||
assertEquals(2.2, observations.get(1).getData().getDouble(0), 0.00001);
|
||||
assertEquals(2, (int)actions.get(1));
|
||||
assertEquals(2.0 + 3.0, rewards.get(1), 0.00001);
|
||||
assertFalse(isTerminalList.get(1));
|
||||
|
||||
ArgumentCaptor<Observation> finalObservationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
verify(learningBehaviorMock, times(1)).handleEpisodeEnd(finalObservationCaptor.capture());
|
||||
assertEquals(4.4, finalObservationCaptor.getValue().getData().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalledMultipleTimes_expect_totalStepCountCorrect() {
|
||||
// Arrange
|
||||
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||
.maxEpisodeSteps(4)
|
||||
.build();
|
||||
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
|
||||
double[] reward = new double[] { 0.0 };
|
||||
when(environmentMock.step(any(Integer.class)))
|
||||
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||
.thenAnswer(new Answer<Observation>() {
|
||||
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||
int step = (int)invocation.getArgument(1);
|
||||
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||
return (step % 2 == 0 || isTerminal)
|
||||
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||
: Observation.SkippedObservation;
|
||||
}
|
||||
});
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
reward[0] = 0.0;
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(8, sut.getTotalStepCount());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_runIsCalledMultipleTimes_expect_rewardSentToLearningBehaviorToBeCorrect() {
|
||||
// Arrange
|
||||
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||
.maxEpisodeSteps(4)
|
||||
.build();
|
||||
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
|
||||
double[] reward = new double[] { 0.0 };
|
||||
when(environmentMock.step(any(Integer.class)))
|
||||
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||
|
||||
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||
.thenAnswer(new Answer<Observation>() {
|
||||
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||
int step = (int)invocation.getArgument(1);
|
||||
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||
return (step % 2 == 0 || isTerminal)
|
||||
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||
: Observation.SkippedObservation;
|
||||
}
|
||||
});
|
||||
|
||||
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
reward[0] = 0.0;
|
||||
sut.run();
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||
|
||||
verify(learningBehaviorMock, times(4)).handleNewExperience(any(Observation.class), any(Integer.class), rewardCaptor.capture(), any(Boolean.class));
|
||||
List<Double> rewards = rewardCaptor.getAllValues();
|
||||
|
||||
// rewardAtLastExperience at the end of 1st call to .run() should not leak into 2nd call.
|
||||
assertEquals(0.0 + 1.0, rewards.get(2), 0.00001);
|
||||
assertEquals(2.0 + 3.0, rewards.get(3), 0.00001);
|
||||
}
|
||||
}
|
|
@ -1,10 +1,7 @@
|
|||
package org.deeplearning4j.rl4j.agent;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
|
||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
||||
import org.deeplearning4j.rl4j.environment.Environment;
|
||||
import org.deeplearning4j.rl4j.environment.Schema;
|
||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||
import org.deeplearning4j.rl4j.environment.*;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
|
@ -12,6 +9,7 @@ import org.junit.Rule;
|
|||
import org.junit.Test;
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.*;
|
||||
import org.mockito.junit.*;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -23,8 +21,8 @@ import java.util.Map;
|
|||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class AgentTest {
|
||||
|
||||
@Mock Environment environmentMock;
|
||||
@Mock TransformProcess transformProcessMock;
|
||||
@Mock IPolicy policyMock;
|
||||
|
@ -102,7 +100,7 @@ public class AgentTest {
|
|||
public void when_runIsCalled_expect_agentIsReset() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -119,7 +117,7 @@ public class AgentTest {
|
|||
sut.run();
|
||||
|
||||
// Assert
|
||||
assertEquals(0, sut.getEpisodeStepNumber());
|
||||
assertEquals(0, sut.getEpisodeStepCount());
|
||||
verify(transformProcessMock).transform(envResetResult, 0, false);
|
||||
verify(policyMock, times(1)).reset();
|
||||
assertEquals(0.0, sut.getReward(), 0.00001);
|
||||
|
@ -130,7 +128,7 @@ public class AgentTest {
|
|||
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -152,7 +150,7 @@ public class AgentTest {
|
|||
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -179,7 +177,7 @@ public class AgentTest {
|
|||
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -191,10 +189,10 @@ public class AgentTest {
|
|||
final Agent spy = Mockito.spy(sut);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepCount();
|
||||
return null;
|
||||
}).when(spy).performStep();
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
|
||||
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepCount() >= 5 );
|
||||
|
||||
// Act
|
||||
spy.run();
|
||||
|
@ -209,7 +207,7 @@ public class AgentTest {
|
|||
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -222,7 +220,7 @@ public class AgentTest {
|
|||
final Agent spy = Mockito.spy(sut);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
||||
((Agent)invocation.getMock()).incrementEpisodeStepCount();
|
||||
return null;
|
||||
}).when(spy).performStep();
|
||||
|
||||
|
@ -239,7 +237,7 @@ public class AgentTest {
|
|||
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -264,7 +262,7 @@ public class AgentTest {
|
|||
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -289,7 +287,7 @@ public class AgentTest {
|
|||
public void when_observationsIsSkipped_expect_performLastAction() {
|
||||
// Arrange
|
||||
Map<String, Object> envResetResult = new HashMap<>();
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
@ -331,7 +329,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
|
||||
|
@ -358,7 +356,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
|
||||
|
@ -381,7 +379,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
|
||||
|
@ -405,7 +403,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
|
@ -430,7 +428,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
|
@ -458,7 +456,7 @@ public class AgentTest {
|
|||
@Test
|
||||
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
|
||||
// Arrange
|
||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
||||
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||
when(environmentMock.getSchema()).thenReturn(schema);
|
||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
package org.deeplearning4j.rl4j.agent.learning;
|
||||
|
||||
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class LearningBehaviorTest {
|
||||
|
||||
@Mock
|
||||
ExperienceHandler<Integer, Object> experienceHandlerMock;
|
||||
|
||||
@Mock
|
||||
IUpdateRule<Object> updateRuleMock;
|
||||
|
||||
LearningBehavior<Integer, Object> sut;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
sut = LearningBehavior.<Integer, Object>builder()
|
||||
.experienceHandler(experienceHandlerMock)
|
||||
.updateRule(updateRuleMock)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingHandleEpisodeStart_expect_experienceHandlerResetCalled() {
|
||||
// Arrange
|
||||
LearningBehavior<Integer, Object> sut = LearningBehavior.<Integer, Object>builder()
|
||||
.experienceHandler(experienceHandlerMock)
|
||||
.updateRule(updateRuleMock)
|
||||
.build();
|
||||
|
||||
// Act
|
||||
sut.handleEpisodeStart();
|
||||
|
||||
// Assert
|
||||
verify(experienceHandlerMock, times(1)).reset();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingHandleNewExperience_expect_experienceHandlerAddExperienceCalled() {
|
||||
// Arrange
|
||||
INDArray observationData = Nd4j.rand(1, 1);
|
||||
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
|
||||
|
||||
// Act
|
||||
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
|
||||
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||
ArgumentCaptor<Boolean> isTerminatedCaptor = ArgumentCaptor.forClass(Boolean.class);
|
||||
verify(experienceHandlerMock, times(1)).addExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminatedCaptor.capture());
|
||||
|
||||
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||
assertEquals(1, (int)actionCaptor.getValue());
|
||||
assertEquals(2.0, (double)rewardCaptor.getValue(), 0.00001);
|
||||
assertFalse(isTerminatedCaptor.getValue());
|
||||
|
||||
verify(updateRuleMock, never()).update(any(List.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingHandleNewExperienceAndTrainingBatchIsReady_expect_updateRuleUpdateWithTrainingBatch() {
|
||||
// Arrange
|
||||
INDArray observationData = Nd4j.rand(1, 1);
|
||||
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
|
||||
List<Object> trainingBatch = new ArrayList<Object>();
|
||||
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
|
||||
|
||||
// Act
|
||||
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
|
||||
|
||||
// Assert
|
||||
verify(updateRuleMock, times(1)).update(trainingBatch);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingHandleEpisodeEnd_expect_experienceHandlerSetFinalObservationCalled() {
|
||||
// Arrange
|
||||
INDArray observationData = Nd4j.rand(1, 1);
|
||||
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
|
||||
|
||||
// Act
|
||||
sut.handleEpisodeEnd(new Observation(observationData));
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
|
||||
|
||||
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||
|
||||
verify(updateRuleMock, never()).update(any(List.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingHandleEpisodeEndAndTrainingBatchIsNotEmpty_expect_updateRuleUpdateWithTrainingBatch() {
|
||||
// Arrange
|
||||
INDArray observationData = Nd4j.rand(1, 1);
|
||||
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
|
||||
List<Object> trainingBatch = new ArrayList<Object>();
|
||||
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
|
||||
|
||||
// Act
|
||||
sut.handleEpisodeEnd(new Observation(observationData));
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
|
||||
|
||||
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||
|
||||
verify(updateRuleMock, times(1)).update(trainingBatch);
|
||||
}
|
||||
}
|
|
@ -4,34 +4,44 @@ import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
|||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class ReplayMemoryExperienceHandlerTest {
|
||||
|
||||
@Mock
|
||||
IExpReplay<Integer> expReplayMock;
|
||||
|
||||
@Test
|
||||
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
when(expReplayMock.getDesignatedBatchSize()).thenReturn(10);
|
||||
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
int numStoredTransitions = expReplayMock.addedTransitions.size();
|
||||
boolean isStoreCalledAfterFirstAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
boolean isStoreCalledAfterSecondAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
|
||||
|
||||
// Assert
|
||||
assertEquals(0, numStoredTransitions);
|
||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||
assertFalse(isStoreCalledAfterFirstAdd);
|
||||
assertTrue(isStoreCalledAfterSecondAdd);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_transitionsAreCorrect() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
|
@ -40,24 +50,25 @@ public class ReplayMemoryExperienceHandlerTest {
|
|||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||
|
||||
// Assert
|
||||
assertEquals(2, expReplayMock.addedTransitions.size());
|
||||
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
|
||||
verify(expReplayMock, times(2)).store(argument.capture());
|
||||
List<Transition<Integer>> transitions = argument.getAllValues();
|
||||
|
||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
||||
assertEquals(1.0, transitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(1, (int)transitions.get(0).getAction());
|
||||
assertEquals(1.0, transitions.get(0).getReward(), 0.00001);
|
||||
assertEquals(2.0, transitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
||||
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
|
||||
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
||||
assertEquals(2.0, transitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(2, (int)transitions.get(1).getAction());
|
||||
assertEquals(2.0, transitions.get(1).getReward(), 0.00001);
|
||||
assertEquals(3.0, transitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
|
@ -66,42 +77,57 @@ public class ReplayMemoryExperienceHandlerTest {
|
|||
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
|
||||
verify(expReplayMock, times(1)).store(argument.capture());
|
||||
Transition<Integer> transition = argument.getValue();
|
||||
|
||||
assertEquals(1, (int)transition.getAction());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||
|
||||
// Act
|
||||
int size = sut.getTrainingBatchSize();
|
||||
|
||||
// Assert
|
||||
assertEquals(2, size);
|
||||
}
|
||||
|
||||
private static class TestExpReplay implements IExpReplay<Integer> {
|
||||
@Test
|
||||
public void when_experienceSizeIsSmallerThanBatchSize_expect_TrainingBatchIsNotReady() {
|
||||
// Arrange
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||
|
||||
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
|
||||
// Act
|
||||
|
||||
@Override
|
||||
public ArrayList<Transition<Integer>> getBatch() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void store(Transition<Integer> transition) {
|
||||
addedTransitions.add(transition);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getBatchSize() {
|
||||
return addedTransitions.size();
|
||||
}
|
||||
// Assert
|
||||
assertFalse(sut.isTrainingBatchReady());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_experienceSizeIsGreaterOrEqualToBatchSize_expect_TrainingBatchIsReady() {
|
||||
// Arrange
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 4.0 })), 4, 4.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 5.0 })), 5, 5.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 6.0 })));
|
||||
|
||||
// Act
|
||||
|
||||
// Assert
|
||||
assertTrue(sut.isTrainingBatchReady());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ public class StateActionExperienceHandlerTest {
|
|||
@Test
|
||||
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||
sut.reset();
|
||||
Observation observation = new Observation(Nd4j.zeros(1));
|
||||
sut.addExperience(observation, 123, 234.0, true);
|
||||
|
@ -32,7 +32,7 @@ public class StateActionExperienceHandlerTest {
|
|||
@Test
|
||||
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
|
@ -51,7 +51,7 @@ public class StateActionExperienceHandlerTest {
|
|||
@Test
|
||||
public void when_gettingExperience_expect_experienceStoreIsCleared() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
|
||||
|
@ -67,7 +67,7 @@ public class StateActionExperienceHandlerTest {
|
|||
@Test
|
||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
|
@ -79,4 +79,66 @@ public class StateActionExperienceHandlerTest {
|
|||
// Assert
|
||||
assertEquals(3, size);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_experienceIsEmpty_expect_TrainingBatchNotReady() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||
sut.reset();
|
||||
|
||||
// Act
|
||||
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||
|
||||
// Assert
|
||||
assertFalse(isTrainingBatchReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_experienceSizeIsGreaterOrEqualToThanBatchSize_expect_TrainingBatchIsReady() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
sut.addExperience(null, 3, 3.0, false);
|
||||
sut.addExperience(null, 4, 4.0, false);
|
||||
sut.addExperience(null, 5, 5.0, false);
|
||||
|
||||
// Act
|
||||
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||
|
||||
// Assert
|
||||
assertTrue(isTrainingBatchReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_experienceSizeIsSmallerThanBatchSizeButFinalObservationIsSet_expect_TrainingBatchIsReady() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
sut.setFinalObservation(null);
|
||||
|
||||
// Act
|
||||
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||
|
||||
// Assert
|
||||
assertTrue(isTrainingBatchReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_experienceSizeIsZeroAndFinalObservationIsSet_expect_TrainingBatchIsNotReady() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||
sut.reset();
|
||||
sut.setFinalObservation(null);
|
||||
|
||||
// Act
|
||||
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||
|
||||
// Assert
|
||||
assertFalse(isTrainingBatchReady);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -49,4 +49,25 @@ public class INDArrayHelperTest {
|
|||
assertEquals(1, output.shape()[1]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingCreateBatchForShape_expect_INDArrayWithCorrectShapeAndOriginalShapeUnchanged() {
|
||||
// Arrange
|
||||
long[] shape = new long[] { 1, 3, 4};
|
||||
|
||||
// Act
|
||||
INDArray output = INDArrayHelper.createBatchForShape(2, shape);
|
||||
|
||||
// Assert
|
||||
// Output shape
|
||||
assertEquals(3, output.shape().length);
|
||||
assertEquals(2, output.shape()[0]);
|
||||
assertEquals(3, output.shape()[1]);
|
||||
assertEquals(4, output.shape()[2]);
|
||||
|
||||
// Input should remain unchanged
|
||||
assertEquals(1, shape[0]);
|
||||
assertEquals(3, shape[1]);
|
||||
assertEquals(4, shape[2]);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,10 +19,11 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
|||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.support.MockDQN;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -32,6 +33,9 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.argThat;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class QLearningUpdateAlgorithmTest {
|
||||
|
@ -39,12 +43,24 @@ public class QLearningUpdateAlgorithmTest {
|
|||
@Mock
|
||||
AsyncGlobal mockAsyncGlobal;
|
||||
|
||||
@Mock
|
||||
IDQN dqnMock;
|
||||
|
||||
private UpdateAlgorithm sut;
|
||||
|
||||
private void setup(double gamma) {
|
||||
// mock a neural net output -- just invert the sign of the input
|
||||
when(dqnMock.outputAll(any(INDArray.class))).thenAnswer(invocation -> new INDArray[] { invocation.getArgument(0, INDArray.class).mul(-1.0) });
|
||||
|
||||
sut = new QLearningUpdateAlgorithm(2, gamma);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_isTerminal_expect_initRewardIs0() {
|
||||
// Arrange
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
|
||||
final Observation observation = new Observation(Nd4j.zeros(1));
|
||||
setup(1.0);
|
||||
|
||||
final Observation observation = new Observation(Nd4j.zeros(1, 2));
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(observation, 0, 0.0, true));
|
||||
|
@ -55,59 +71,68 @@ public class QLearningUpdateAlgorithmTest {
|
|||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 0.0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
|
||||
// Arrange
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
|
||||
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
|
||||
setup(1.0);
|
||||
|
||||
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }).reshape(1, 2));
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(observation, 0, 0.0, false));
|
||||
}
|
||||
};
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, dqnMock.outputAllParams.size());
|
||||
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
||||
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||
ArgumentCaptor<INDArray> argument = ArgumentCaptor.forClass(INDArray.class);
|
||||
|
||||
verify(dqnMock, times(2)).outputAll(argument.capture());
|
||||
List<INDArray> values = argument.getAllValues();
|
||||
assertEquals(-123.0, values.get(0).getDouble(0, 0), 0.00001);
|
||||
assertEquals(-123.0, values.get(1).getDouble(0, 0), 0.00001);
|
||||
|
||||
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 234.0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
|
||||
// Arrange
|
||||
double gamma = 0.9;
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
|
||||
setup(gamma);
|
||||
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 }).reshape(1, 2)), 0, 1.0, false));
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 }).reshape(1, 2)), 1, 2.0, true));
|
||||
}
|
||||
};
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
ArgumentCaptor<INDArray> features = ArgumentCaptor.forClass(INDArray.class);
|
||||
ArgumentCaptor<INDArray> targets = ArgumentCaptor.forClass(INDArray.class);
|
||||
verify(dqnMock, times(1)).gradient(features.capture(), targets.capture());
|
||||
|
||||
// input side -- should be a stack of observations
|
||||
INDArray input = dqnMock.gradientParams.get(0).getLeft();
|
||||
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
|
||||
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
|
||||
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
|
||||
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
|
||||
INDArray featuresValues = features.getValue();
|
||||
assertEquals(-1.1, featuresValues.getDouble(0, 0), 0.00001);
|
||||
assertEquals(-1.2, featuresValues.getDouble(0, 1), 0.00001);
|
||||
assertEquals(-2.1, featuresValues.getDouble(1, 0), 0.00001);
|
||||
assertEquals(-2.2, featuresValues.getDouble(1, 1), 0.00001);
|
||||
|
||||
// target side
|
||||
INDArray target = dqnMock.gradientParams.get(0).getRight();
|
||||
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
|
||||
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
|
||||
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
|
||||
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
|
||||
INDArray targetsValues = targets.getValue();
|
||||
assertEquals(1.0 + gamma * 2.0, targetsValues.getDouble(0, 0), 0.00001);
|
||||
assertEquals(1.2, targetsValues.getDouble(0, 1), 0.00001);
|
||||
assertEquals(2.1, targetsValues.getDouble(1, 0), 0.00001);
|
||||
assertEquals(2.0, targetsValues.getDouble(1, 1), 0.00001);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
|
@ -74,6 +75,9 @@ public class QLearningDiscreteTest {
|
|||
@Mock
|
||||
QLearningConfiguration mockQlearningConfiguration;
|
||||
|
||||
@Mock
|
||||
ILearningBehavior<Integer> learningBehavior;
|
||||
|
||||
// HWC
|
||||
int[] observationShape = new int[]{3, 10, 10};
|
||||
int totalObservationSize = 1;
|
||||
|
@ -92,18 +96,28 @@ public class QLearningDiscreteTest {
|
|||
}
|
||||
|
||||
|
||||
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
|
||||
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay, ILearningBehavior<Integer> learningBehavior) {
|
||||
when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
|
||||
when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
|
||||
when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
|
||||
when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
|
||||
|
||||
qLearningDiscrete = mock(
|
||||
QLearningDiscrete.class,
|
||||
Mockito.withSettings()
|
||||
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
|
||||
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
||||
);
|
||||
if(learningBehavior != null) {
|
||||
qLearningDiscrete = mock(
|
||||
QLearningDiscrete.class,
|
||||
Mockito.withSettings()
|
||||
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0, learningBehavior, Nd4j.getRandom())
|
||||
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
||||
);
|
||||
}
|
||||
else {
|
||||
qLearningDiscrete = mock(
|
||||
QLearningDiscrete.class,
|
||||
Mockito.withSettings()
|
||||
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
|
||||
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private void mockHistoryProcessor(int skipFrames) {
|
||||
|
@ -136,7 +150,7 @@ public class QLearningDiscreteTest {
|
|||
public void when_singleTrainStep_expect_correctValues() {
|
||||
|
||||
// Arrange
|
||||
mockTestContext(100,0,2,1.0, 10);
|
||||
mockTestContext(100,0,2,1.0, 10, null);
|
||||
|
||||
// An example observation and 2 Q values output (2 actions)
|
||||
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
||||
|
@ -162,7 +176,7 @@ public class QLearningDiscreteTest {
|
|||
@Test
|
||||
public void when_singleTrainStepSkippedFrames_expect_correctValues() {
|
||||
// Arrange
|
||||
mockTestContext(100,0,2,1.0, 10);
|
||||
mockTestContext(100,0,2,1.0, 10, learningBehavior);
|
||||
|
||||
Observation skippedObservation = Observation.SkippedObservation;
|
||||
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
|
||||
|
@ -180,8 +194,8 @@ public class QLearningDiscreteTest {
|
|||
assertEquals(0, stepReply.getReward(), 1e-5);
|
||||
assertFalse(stepReply.isDone());
|
||||
assertFalse(stepReply.getObservation().isSkipped());
|
||||
assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
|
||||
|
||||
verify(learningBehavior, never()).handleNewExperience(any(Observation.class), any(Integer.class), any(Double.class), any(Boolean.class));
|
||||
verify(mockDQN, never()).output(any(INDArray.class));
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue