parent
a18417193d
commit
5568b9d72f
|
@ -1,3 +1,18 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.agent;
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
|
@ -14,7 +29,13 @@ import org.nd4j.common.base.Preconditions;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public class Agent<ACTION> {
|
/**
|
||||||
|
* An agent implementation. The Agent will use a {@link IPolicy} to interact with an {@link Environment} and receive
|
||||||
|
* a reward.
|
||||||
|
*
|
||||||
|
* @param <ACTION> The type of action
|
||||||
|
*/
|
||||||
|
public class Agent<ACTION> implements IAgent<ACTION> {
|
||||||
@Getter
|
@Getter
|
||||||
private final String id;
|
private final String id;
|
||||||
|
|
||||||
|
@ -37,19 +58,28 @@ public class Agent<ACTION> {
|
||||||
private ACTION lastAction;
|
private ACTION lastAction;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private int episodeStepNumber;
|
private int episodeStepCount;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private double reward;
|
private double reward;
|
||||||
|
|
||||||
protected boolean canContinue;
|
protected boolean canContinue;
|
||||||
|
|
||||||
private Agent(Builder<ACTION> builder) {
|
/**
|
||||||
this.environment = builder.environment;
|
* @param environment The {@link Environment} to be used
|
||||||
this.transformProcess = builder.transformProcess;
|
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
|
||||||
this.policy = builder.policy;
|
* @param policy The {@link IPolicy} to be used
|
||||||
this.maxEpisodeSteps = builder.maxEpisodeSteps;
|
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
|
||||||
this.id = builder.id;
|
* @param id A user-supplied id to identify the instance.
|
||||||
|
*/
|
||||||
|
public Agent(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id) {
|
||||||
|
Preconditions.checkArgument(maxEpisodeSteps == null || maxEpisodeSteps > 0, "maxEpisodeSteps must be null (no maximum) or greater than 0, got", maxEpisodeSteps);
|
||||||
|
|
||||||
|
this.environment = environment;
|
||||||
|
this.transformProcess = transformProcess;
|
||||||
|
this.policy = policy;
|
||||||
|
this.maxEpisodeSteps = maxEpisodeSteps;
|
||||||
|
this.id = id;
|
||||||
|
|
||||||
listeners = buildListenerList();
|
listeners = buildListenerList();
|
||||||
}
|
}
|
||||||
|
@ -58,10 +88,17 @@ public class Agent<ACTION> {
|
||||||
return new AgentListenerList<ACTION>();
|
return new AgentListenerList<ACTION>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a {@link AgentListener} that will be notified when agent events happens
|
||||||
|
* @param listener
|
||||||
|
*/
|
||||||
public void addListener(AgentListener listener) {
|
public void addListener(AgentListener listener) {
|
||||||
listeners.add(listener);
|
listeners.add(listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This will run a single episode
|
||||||
|
*/
|
||||||
public void run() {
|
public void run() {
|
||||||
runEpisode();
|
runEpisode();
|
||||||
}
|
}
|
||||||
|
@ -80,7 +117,7 @@ public class Agent<ACTION> {
|
||||||
|
|
||||||
canContinue = listeners.notifyBeforeEpisode(this);
|
canContinue = listeners.notifyBeforeEpisode(this);
|
||||||
|
|
||||||
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) {
|
while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepCount < maxEpisodeSteps)) {
|
||||||
performStep();
|
performStep();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,9 +137,9 @@ public class Agent<ACTION> {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void resetEnvironment() {
|
protected void resetEnvironment() {
|
||||||
episodeStepNumber = 0;
|
episodeStepCount = 0;
|
||||||
Map<String, Object> channelsData = environment.reset();
|
Map<String, Object> channelsData = environment.reset();
|
||||||
this.observation = transformProcess.transform(channelsData, episodeStepNumber, false);
|
this.observation = transformProcess.transform(channelsData, episodeStepCount, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void resetPolicy() {
|
protected void resetPolicy() {
|
||||||
|
@ -125,7 +162,6 @@ public class Agent<ACTION> {
|
||||||
}
|
}
|
||||||
|
|
||||||
StepResult stepResult = act(action);
|
StepResult stepResult = act(action);
|
||||||
handleStepResult(stepResult);
|
|
||||||
|
|
||||||
onAfterStep(stepResult);
|
onAfterStep(stepResult);
|
||||||
|
|
||||||
|
@ -134,11 +170,11 @@ public class Agent<ACTION> {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
incrementEpisodeStepNumber();
|
incrementEpisodeStepCount();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void incrementEpisodeStepNumber() {
|
protected void incrementEpisodeStepCount() {
|
||||||
++episodeStepNumber;
|
++episodeStepCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected ACTION decideAction(Observation observation) {
|
protected ACTION decideAction(Observation observation) {
|
||||||
|
@ -150,12 +186,15 @@ public class Agent<ACTION> {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected StepResult act(ACTION action) {
|
protected StepResult act(ACTION action) {
|
||||||
return environment.step(action);
|
Observation observationBeforeAction = observation;
|
||||||
}
|
|
||||||
|
|
||||||
protected void handleStepResult(StepResult stepResult) {
|
StepResult stepResult = environment.step(action);
|
||||||
observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1);
|
observation = convertChannelDataToObservation(stepResult, episodeStepCount + 1);
|
||||||
reward +=computeReward(stepResult);
|
reward += computeReward(stepResult);
|
||||||
|
|
||||||
|
onAfterAction(observationBeforeAction, action, stepResult);
|
||||||
|
|
||||||
|
return stepResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
|
protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) {
|
||||||
|
@ -166,6 +205,10 @@ public class Agent<ACTION> {
|
||||||
return stepResult.getReward();
|
return stepResult.getReward();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
|
||||||
|
// Do Nothing
|
||||||
|
}
|
||||||
|
|
||||||
protected void onAfterStep(StepResult stepResult) {
|
protected void onAfterStep(StepResult stepResult) {
|
||||||
// Do Nothing
|
// Do Nothing
|
||||||
}
|
}
|
||||||
|
@ -174,16 +217,24 @@ public class Agent<ACTION> {
|
||||||
// Do Nothing
|
// Do Nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
public static <ACTION> Builder<ACTION> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
/**
|
||||||
|
*
|
||||||
|
* @param environment
|
||||||
|
* @param transformProcess
|
||||||
|
* @param policy
|
||||||
|
* @param <ACTION>
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static <ACTION> Builder<ACTION, Agent> builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||||
return new Builder<>(environment, transformProcess, policy);
|
return new Builder<>(environment, transformProcess, policy);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder<ACTION> {
|
public static class Builder<ACTION, AGENT_TYPE extends Agent> {
|
||||||
private final Environment<ACTION> environment;
|
protected final Environment<ACTION> environment;
|
||||||
private final TransformProcess transformProcess;
|
protected final TransformProcess transformProcess;
|
||||||
private final IPolicy<ACTION> policy;
|
protected final IPolicy<ACTION> policy;
|
||||||
private Integer maxEpisodeSteps = null; // Default, no max
|
protected Integer maxEpisodeSteps = null; // Default, no max
|
||||||
private String id;
|
protected String id;
|
||||||
|
|
||||||
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
public Builder(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> policy) {
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
|
@ -191,20 +242,20 @@ public class Agent<ACTION> {
|
||||||
this.policy = policy;
|
this.policy = policy;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder<ACTION> maxEpisodeSteps(int maxEpisodeSteps) {
|
public Builder<ACTION, AGENT_TYPE> maxEpisodeSteps(int maxEpisodeSteps) {
|
||||||
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
|
Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps);
|
||||||
this.maxEpisodeSteps = maxEpisodeSteps;
|
this.maxEpisodeSteps = maxEpisodeSteps;
|
||||||
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Builder<ACTION> id(String id) {
|
public Builder<ACTION, AGENT_TYPE> id(String id) {
|
||||||
this.id = id;
|
this.id = id;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Agent build() {
|
public AGENT_TYPE build() {
|
||||||
return new Agent(this);
|
return (AGENT_TYPE)new Agent<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||||
|
import org.deeplearning4j.rl4j.environment.Environment;
|
||||||
|
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The ActionLearner is an {@link Agent} that delegate the learning to a {@link ILearningBehavior}.
|
||||||
|
* @param <ACTION> The type of the action
|
||||||
|
*/
|
||||||
|
public class AgentLearner<ACTION> extends Agent<ACTION> implements IAgentLearner<ACTION> {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private int totalStepCount = 0;
|
||||||
|
|
||||||
|
private final ILearningBehavior<ACTION> learningBehavior;
|
||||||
|
private double rewardAtLastExperience;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param environment The {@link Environment} to be used
|
||||||
|
* @param transformProcess The {@link TransformProcess} to be used to transform the raw observations into usable ones.
|
||||||
|
* @param policy The {@link IPolicy} to be used
|
||||||
|
* @param maxEpisodeSteps The maximum number of steps an episode can have before being interrupted. Use null to have no max.
|
||||||
|
* @param id A user-supplied id to identify the instance.
|
||||||
|
* @param learningBehavior The {@link ILearningBehavior} that will be used to supervise the learning.
|
||||||
|
*/
|
||||||
|
public AgentLearner(Environment<ACTION> environment, TransformProcess transformProcess, IPolicy<ACTION> policy, Integer maxEpisodeSteps, String id, @NonNull ILearningBehavior<ACTION> learningBehavior) {
|
||||||
|
super(environment, transformProcess, policy, maxEpisodeSteps, id);
|
||||||
|
|
||||||
|
this.learningBehavior = learningBehavior;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void reset() {
|
||||||
|
super.reset();
|
||||||
|
|
||||||
|
rewardAtLastExperience = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void onBeforeEpisode() {
|
||||||
|
super.onBeforeEpisode();
|
||||||
|
|
||||||
|
learningBehavior.handleEpisodeStart();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void onAfterAction(Observation observationBeforeAction, ACTION action, StepResult stepResult) {
|
||||||
|
if(!observationBeforeAction.isSkipped()) {
|
||||||
|
double rewardSinceLastExperience = getReward() - rewardAtLastExperience;
|
||||||
|
learningBehavior.handleNewExperience(observationBeforeAction, action, rewardSinceLastExperience, stepResult.isTerminal());
|
||||||
|
|
||||||
|
rewardAtLastExperience = getReward();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void onAfterEpisode() {
|
||||||
|
learningBehavior.handleEpisodeEnd(getObservation());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void incrementEpisodeStepCount() {
|
||||||
|
super.incrementEpisodeStepCount();
|
||||||
|
++totalStepCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: parent is still visible
|
||||||
|
public static <ACTION> AgentLearner.Builder<ACTION, AgentLearner<ACTION>> builder(Environment<ACTION> environment,
|
||||||
|
TransformProcess transformProcess,
|
||||||
|
IPolicy<ACTION> policy,
|
||||||
|
ILearningBehavior<ACTION> learningBehavior) {
|
||||||
|
return new AgentLearner.Builder<ACTION, AgentLearner<ACTION>>(environment, transformProcess, policy, learningBehavior);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder<ACTION, AGENT_TYPE extends AgentLearner<ACTION>> extends Agent.Builder<ACTION, AGENT_TYPE> {
|
||||||
|
|
||||||
|
private final ILearningBehavior<ACTION> learningBehavior;
|
||||||
|
|
||||||
|
public Builder(@NonNull Environment<ACTION> environment,
|
||||||
|
@NonNull TransformProcess transformProcess,
|
||||||
|
@NonNull IPolicy<ACTION> policy,
|
||||||
|
@NonNull ILearningBehavior<ACTION> learningBehavior) {
|
||||||
|
super(environment, transformProcess, policy);
|
||||||
|
|
||||||
|
this.learningBehavior = learningBehavior;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AGENT_TYPE build() {
|
||||||
|
return (AGENT_TYPE)new AgentLearner<ACTION>(environment, transformProcess, policy, maxEpisodeSteps, id, learningBehavior);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.environment.Environment;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The interface of {@link Agent}
|
||||||
|
* @param <ACTION>
|
||||||
|
*/
|
||||||
|
public interface IAgent<ACTION> {
|
||||||
|
/**
|
||||||
|
* Will play a single episode
|
||||||
|
*/
|
||||||
|
void run();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return A user-supplied id to identify the IAgent instance.
|
||||||
|
*/
|
||||||
|
String getId();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The {@link Environment} instance being used by the agent.
|
||||||
|
*/
|
||||||
|
Environment<ACTION> getEnvironment();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The {@link IPolicy} instance being used by the agent.
|
||||||
|
*/
|
||||||
|
IPolicy<ACTION> getPolicy();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The step count taken in the current episode.
|
||||||
|
*/
|
||||||
|
int getEpisodeStepCount();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The cumulative reward received in the current episode.
|
||||||
|
*/
|
||||||
|
double getReward();
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
|
public interface IAgentLearner<ACTION> extends IAgent<ACTION> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The total count of steps taken by this AgentLearner, for all episodes.
|
||||||
|
*/
|
||||||
|
int getTotalStepCount();
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.learning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The <code>ILearningBehavior</code> implementations are in charge of the training. Through this interface, they are
|
||||||
|
* notified as new experience is generated.
|
||||||
|
*
|
||||||
|
* @param <ACTION> The type of action
|
||||||
|
*/
|
||||||
|
public interface ILearningBehavior<ACTION> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method is called when a new episode has been started.
|
||||||
|
*/
|
||||||
|
void handleEpisodeStart();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method is called when new experience is generated.
|
||||||
|
*
|
||||||
|
* @param observation The observation prior to taking the action
|
||||||
|
* @param action The action that has been taken
|
||||||
|
* @param reward The reward received by taking the action
|
||||||
|
* @param isTerminal True if the episode ended after taking the action
|
||||||
|
*/
|
||||||
|
void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method is called when the episode ends or the maximum number of episode steps is reached.
|
||||||
|
*
|
||||||
|
* @param finalObservation The observation after the last action of the episode has been taken.
|
||||||
|
*/
|
||||||
|
void handleEpisodeEnd(Observation finalObservation);
|
||||||
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.learning;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A generic {@link ILearningBehavior} that delegates the handling of experience to a {@link ExperienceHandler} and
|
||||||
|
* the update logic to a {@link IUpdateRule}
|
||||||
|
*
|
||||||
|
* @param <ACTION> The type of the action
|
||||||
|
* @param <EXPERIENCE_TYPE> The type of experience the ExperienceHandler needs
|
||||||
|
*/
|
||||||
|
@Builder
|
||||||
|
public class LearningBehavior<ACTION, EXPERIENCE_TYPE> implements ILearningBehavior<ACTION> {
|
||||||
|
|
||||||
|
@Builder.Default
|
||||||
|
private int experienceUpdateSize = 64;
|
||||||
|
|
||||||
|
private final ExperienceHandler<ACTION, EXPERIENCE_TYPE> experienceHandler;
|
||||||
|
private final IUpdateRule<EXPERIENCE_TYPE> updateRule;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleEpisodeStart() {
|
||||||
|
experienceHandler.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleNewExperience(Observation observation, ACTION action, double reward, boolean isTerminal) {
|
||||||
|
experienceHandler.addExperience(observation, action, reward, isTerminal);
|
||||||
|
if(experienceHandler.isTrainingBatchReady()) {
|
||||||
|
updateRule.update(experienceHandler.generateTrainingBatch());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void handleEpisodeEnd(Observation finalObservation) {
|
||||||
|
experienceHandler.setFinalObservation(finalObservation);
|
||||||
|
if(experienceHandler.isTrainingBatchReady()) {
|
||||||
|
updateRule.update(experienceHandler.generateTrainingBatch());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,23 +1,66 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.agent.listener;
|
package org.deeplearning4j.rl4j.agent.listener;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.agent.Agent;
|
import org.deeplearning4j.rl4j.agent.Agent;
|
||||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The base definition of all {@link Agent} event listeners
|
||||||
|
*/
|
||||||
public interface AgentListener<ACTION> {
|
public interface AgentListener<ACTION> {
|
||||||
enum ListenerResponse {
|
enum ListenerResponse {
|
||||||
/**
|
/**
|
||||||
* Tell the learning process to continue calling the listeners and the training.
|
* Tell the {@link Agent} to continue calling the listeners and the processing.
|
||||||
*/
|
*/
|
||||||
CONTINUE,
|
CONTINUE,
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tell the learning process to stop calling the listeners and terminate the training.
|
* Tell the {@link Agent} to interrupt calling the listeners and stop the processing.
|
||||||
*/
|
*/
|
||||||
STOP,
|
STOP,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when a new episode is about to start.
|
||||||
|
* @param agent The agent that generated the event
|
||||||
|
*
|
||||||
|
* @return A {@link ListenerResponse}.
|
||||||
|
*/
|
||||||
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
|
AgentListener.ListenerResponse onBeforeEpisode(Agent agent);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when a step is about to be taken.
|
||||||
|
*
|
||||||
|
* @param agent The agent that generated the event
|
||||||
|
* @param observation The observation before the action is taken
|
||||||
|
* @param action The action that will be performed
|
||||||
|
*
|
||||||
|
* @return A {@link ListenerResponse}.
|
||||||
|
*/
|
||||||
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
|
AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called after a step has been taken.
|
||||||
|
*
|
||||||
|
* @param agent The agent that generated the event
|
||||||
|
* @param stepResult The {@link StepResult} result of the step.
|
||||||
|
*
|
||||||
|
* @return A {@link ListenerResponse}.
|
||||||
|
*/
|
||||||
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
|
AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.agent.listener;
|
package org.deeplearning4j.rl4j.agent.listener;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.agent.Agent;
|
import org.deeplearning4j.rl4j.agent.Agent;
|
||||||
|
@ -7,6 +22,10 @@ import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A class that manages a list of {@link AgentListener AgentListeners} listening to an {@link Agent}.
|
||||||
|
* @param <ACTION>
|
||||||
|
*/
|
||||||
public class AgentListenerList<ACTION> {
|
public class AgentListenerList<ACTION> {
|
||||||
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
|
protected final List<AgentListener<ACTION>> listeners = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -18,6 +37,13 @@ public class AgentListenerList<ACTION> {
|
||||||
listeners.add(listener);
|
listeners.add(listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will notify all listeners that an episode is about to start. If a listener returns
|
||||||
|
* {@link AgentListener.ListenerResponse STOP}, any following listener is skipped.
|
||||||
|
*
|
||||||
|
* @param agent The agent that generated the event.
|
||||||
|
* @return False if the processing should be stopped
|
||||||
|
*/
|
||||||
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
|
public boolean notifyBeforeEpisode(Agent<ACTION> agent) {
|
||||||
for (AgentListener<ACTION> listener : listeners) {
|
for (AgentListener<ACTION> listener : listeners) {
|
||||||
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
|
if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) {
|
||||||
|
@ -28,6 +54,13 @@ public class AgentListenerList<ACTION> {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param agent The agent that generated the event.
|
||||||
|
* @param observation The observation before the action is taken
|
||||||
|
* @param action The action that will be performed
|
||||||
|
* @return False if the processing should be stopped
|
||||||
|
*/
|
||||||
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
|
public boolean notifyBeforeStep(Agent<ACTION> agent, Observation observation, ACTION action) {
|
||||||
for (AgentListener<ACTION> listener : listeners) {
|
for (AgentListener<ACTION> listener : listeners) {
|
||||||
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
|
if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) {
|
||||||
|
@ -38,6 +71,12 @@ public class AgentListenerList<ACTION> {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param agent The agent that generated the event.
|
||||||
|
* @param stepResult The {@link StepResult} result of the step.
|
||||||
|
* @return False if the processing should be stopped
|
||||||
|
*/
|
||||||
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
|
public boolean notifyAfterStep(Agent<ACTION> agent, StepResult stepResult) {
|
||||||
for (AgentListener<ACTION> listener : listeners) {
|
for (AgentListener<ACTION> listener : listeners) {
|
||||||
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
|
if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) {
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.update;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
// Temporary class that will be replaced with a more generic class that delegates gradient computation
|
||||||
|
// and network update to sub components.
|
||||||
|
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final IDQN qNetwork;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private IDQN targetQNetwork;
|
||||||
|
private final int targetUpdateFrequency;
|
||||||
|
|
||||||
|
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private int updateCount = 0;
|
||||||
|
|
||||||
|
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
||||||
|
this.qNetwork = qNetwork;
|
||||||
|
this.targetQNetwork = qNetwork.clone();
|
||||||
|
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||||
|
tdTargetAlgorithm = isDoubleDQN
|
||||||
|
? new DoubleDQN(this, gamma, errorClamp)
|
||||||
|
: new StandardDQN(this, gamma, errorClamp);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void update(List<Transition<Integer>> trainingBatch) {
|
||||||
|
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
||||||
|
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
||||||
|
if(++updateCount % targetUpdateFrequency == 0) {
|
||||||
|
targetQNetwork = qNetwork.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.update;
|
||||||
|
|
||||||
|
import lombok.Value;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
|
||||||
|
// Work in progress
|
||||||
|
@Value
|
||||||
|
public class Gradients {
|
||||||
|
private Gradient[] gradients; // Temporary: we'll need something better than a Gradient[]
|
||||||
|
private int batchSize;
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.update;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The role of IUpdateRule implementations is to use an experience batch to improve the accuracy of the policy.
|
||||||
|
* Used by {@link org.deeplearning4j.rl4j.agent.AgentLearner AgentLearner}
|
||||||
|
* @param <EXPERIENCE_TYPE> The type of the experience
|
||||||
|
*/
|
||||||
|
public interface IUpdateRule<EXPERIENCE_TYPE> {
|
||||||
|
/**
|
||||||
|
* Perform the update
|
||||||
|
* @param trainingBatch A batch of experience
|
||||||
|
*/
|
||||||
|
void update(List<EXPERIENCE_TYPE> trainingBatch);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The total number of times the policy has been updated. In a multi-agent learning context, this total is
|
||||||
|
* for all the agents.
|
||||||
|
*/
|
||||||
|
int getUpdateCount();
|
||||||
|
}
|
|
@ -1,9 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.environment;
|
|
||||||
|
|
||||||
import lombok.Value;
|
|
||||||
|
|
||||||
@Value
|
|
||||||
public class ActionSchema<ACTION> {
|
|
||||||
private ACTION noOp;
|
|
||||||
//FIXME ACTION randomAction();
|
|
||||||
}
|
|
|
@ -1,11 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.environment;
|
package org.deeplearning4j.rl4j.environment;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An interface for environments used by the {@link org.deeplearning4j.rl4j.agent.Agent Agents}.
|
||||||
|
* @param <ACTION> The type of actions
|
||||||
|
*/
|
||||||
public interface Environment<ACTION> {
|
public interface Environment<ACTION> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The {@link Schema} of the environment
|
||||||
|
*/
|
||||||
Schema<ACTION> getSchema();
|
Schema<ACTION> getSchema();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset the environment's state to start a new episode.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
Map<String, Object> reset();
|
Map<String, Object> reset();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform a single step.
|
||||||
|
*
|
||||||
|
* @param action The action taken
|
||||||
|
* @return A {@link StepResult} describing the result of the step.
|
||||||
|
*/
|
||||||
StepResult step(ACTION action);
|
StepResult step(ACTION action);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return True if the episode is finished
|
||||||
|
*/
|
||||||
boolean isEpisodeFinished();
|
boolean isEpisodeFinished();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the agent is finished using this environment instance.
|
||||||
|
*/
|
||||||
void close();
|
void close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.environment;
|
||||||
|
|
||||||
|
import lombok.Value;
|
||||||
|
|
||||||
|
// Work in progress
|
||||||
|
public interface IActionSchema<ACTION> {
|
||||||
|
ACTION getNoOp();
|
||||||
|
|
||||||
|
// Review: A schema should be data-only and not have behavior
|
||||||
|
ACTION getRandomAction();
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.environment;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
// Work in progress
|
||||||
|
public class IntegerActionSchema implements IActionSchema<Integer> {
|
||||||
|
|
||||||
|
private final int numActions;
|
||||||
|
private final int noOpAction;
|
||||||
|
private final Random rnd;
|
||||||
|
|
||||||
|
public IntegerActionSchema(int numActions, int noOpAction) {
|
||||||
|
this(numActions, noOpAction, Nd4j.getRandom());
|
||||||
|
}
|
||||||
|
|
||||||
|
public IntegerActionSchema(int numActions, int noOpAction, Random rnd) {
|
||||||
|
this.numActions = numActions;
|
||||||
|
this.noOpAction = noOpAction;
|
||||||
|
this.rnd = rnd;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getNoOp() {
|
||||||
|
return noOpAction;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer getRandomAction() {
|
||||||
|
return rnd.nextInt(numActions);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,8 +1,24 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.environment;
|
package org.deeplearning4j.rl4j.environment;
|
||||||
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
|
|
||||||
|
// Work in progress
|
||||||
@Value
|
@Value
|
||||||
public class Schema<ACTION> {
|
public class Schema<ACTION> {
|
||||||
private ActionSchema<ACTION> actionSchema;
|
private IActionSchema<ACTION> actionSchema;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
package org.deeplearning4j.rl4j.environment;
|
package org.deeplearning4j.rl4j.environment;
|
||||||
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
|
|
|
@ -41,6 +41,11 @@ public interface ExperienceHandler<A, E> {
|
||||||
*/
|
*/
|
||||||
int getTrainingBatchSize();
|
int getTrainingBatchSize();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return True if a batch is ready for training.
|
||||||
|
*/
|
||||||
|
boolean isTrainingBatchReady();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||||
* @return The list of experience elements
|
* @return The list of experience elements
|
||||||
|
|
|
@ -36,6 +36,7 @@ import java.util.List;
|
||||||
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
|
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
|
||||||
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
|
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
|
||||||
private static final int DEFAULT_BATCH_SIZE = 32;
|
private static final int DEFAULT_BATCH_SIZE = 32;
|
||||||
|
private final int batchSize;
|
||||||
|
|
||||||
private IExpReplay<A> expReplay;
|
private IExpReplay<A> expReplay;
|
||||||
|
|
||||||
|
@ -43,6 +44,7 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
|
||||||
|
|
||||||
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
|
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
|
||||||
this.expReplay = expReplay;
|
this.expReplay = expReplay;
|
||||||
|
this.batchSize = expReplay.getDesignatedBatchSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
|
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
|
||||||
|
@ -64,6 +66,11 @@ public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Tr
|
||||||
return expReplay.getBatchSize();
|
return expReplay.getBatchSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isTrainingBatchReady() {
|
||||||
|
return expReplay.getBatchSize() >= batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
|
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -30,10 +30,18 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
|
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
|
||||||
|
|
||||||
|
private final int batchSize;
|
||||||
|
|
||||||
|
private boolean isFinalObservationSet;
|
||||||
|
|
||||||
|
public StateActionExperienceHandler(int batchSize) {
|
||||||
|
this.batchSize = batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
|
private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();
|
||||||
|
|
||||||
public void setFinalObservation(Observation observation) {
|
public void setFinalObservation(Observation observation) {
|
||||||
// Do nothing
|
isFinalObservationSet = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||||
|
@ -45,6 +53,12 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
|
||||||
return stateActionPairs.size();
|
return stateActionPairs.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isTrainingBatchReady() {
|
||||||
|
return stateActionPairs.size() >= batchSize
|
||||||
|
|| (isFinalObservationSet && stateActionPairs.size() > 0);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||||
* Note: the experience store is cleared after calling this method.
|
* Note: the experience store is cleared after calling this method.
|
||||||
|
@ -62,6 +76,7 @@ public class StateActionExperienceHandler<A> implements ExperienceHandler<A, Sta
|
||||||
@Override
|
@Override
|
||||||
public void reset() {
|
public void reset() {
|
||||||
stateActionPairs = new ArrayList<>();
|
stateActionPairs = new ArrayList<>();
|
||||||
|
isFinalObservationSet = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,17 +24,38 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
* @author Alexandre Boulanger
|
* @author Alexandre Boulanger
|
||||||
*/
|
*/
|
||||||
public class INDArrayHelper {
|
public class INDArrayHelper {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types.
|
* Force the input source to have the correct shape:
|
||||||
*
|
* <p><ul>
|
||||||
* We must have either shape 2 (NK) or shape 4 (NCHW)
|
* <li>DL4J requires it to be at least 2D</li>
|
||||||
|
* <li>RL4J has a convention to have the batch size on dimension 0 to all INDArrays</li>
|
||||||
|
* </ul></p>
|
||||||
|
* @param source The {@link INDArray} to be corrected.
|
||||||
|
* @return The corrected INDArray
|
||||||
*/
|
*/
|
||||||
public static INDArray forceCorrectShape(INDArray source) {
|
public static INDArray forceCorrectShape(INDArray source) {
|
||||||
|
|
||||||
return source.shape()[0] == 1 && source.shape().length > 1
|
return source.shape()[0] == 1 && source.rank() > 1
|
||||||
? source
|
? source
|
||||||
: Nd4j.expandDims(source, 0);
|
: Nd4j.expandDims(source, 0);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This will create a INDArray with <i>batchSize</i> as dimension 0 and <i>shape</i> as other dimensions.
|
||||||
|
* For example, if <i>batchSize</i> is 10 and shape is { 1, 3, 4 }, the resulting INDArray shape will be { 10, 3, 4}
|
||||||
|
* @param batchSize The size of the batch to create
|
||||||
|
* @param shape The shape of individual elements.
|
||||||
|
* Note: all shapes in RL4J should have a batch size as dimension 0; in this case the batch size should be 1.
|
||||||
|
* @return A INDArray
|
||||||
|
*/
|
||||||
|
public static INDArray createBatchForShape(long batchSize, long... shape) {
|
||||||
|
long[] batchShape;
|
||||||
|
|
||||||
|
batchShape = new long[shape.length];
|
||||||
|
System.arraycopy(shape, 0, batchShape, 0, shape.length);
|
||||||
|
|
||||||
|
batchShape[0] = batchSize;
|
||||||
|
return Nd4j.create(batchShape);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
@ -49,7 +50,7 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
|
||||||
|
|
||||||
// TODO: Make it configurable with a builder
|
// TODO: Make it configurable with a builder
|
||||||
@Setter(AccessLevel.PROTECTED) @Getter
|
@Setter(AccessLevel.PROTECTED) @Getter
|
||||||
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
private ExperienceHandler experienceHandler;
|
||||||
|
|
||||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
||||||
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
MDP<OBSERVATION, Integer, DiscreteSpace> mdp,
|
||||||
|
@ -60,6 +61,17 @@ public abstract class AsyncThreadDiscrete<OBSERVATION extends Encodable, NN exte
|
||||||
synchronized (asyncGlobal) {
|
synchronized (asyncGlobal) {
|
||||||
current = (NN) asyncGlobal.getTarget().clone();
|
current = (NN) asyncGlobal.getTarget().clone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
experienceHandler = new StateActionExperienceHandler(getNStep());
|
||||||
|
}
|
||||||
|
|
||||||
|
private int getNStep() {
|
||||||
|
IAsyncLearningConfiguration configuration = getConfiguration();
|
||||||
|
if(configuration == null) {
|
||||||
|
return Integer.MAX_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
return configuration.getNStep();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Add an actor-learner class and be able to inject the update algorithm
|
// TODO: Add an actor-learner class and be able to inject the update algorithm
|
||||||
|
|
|
@ -71,7 +71,6 @@ public class AsyncNStepQLearningThreadDiscrete<OBSERVATION extends Encodable> ex
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
return new QLearningUpdateAlgorithm(getMdp().getActionSpace().getSize(), configuration.getGamma());
|
||||||
return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
|
||||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -27,15 +27,12 @@ import java.util.List;
|
||||||
|
|
||||||
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
||||||
|
|
||||||
private final int[] shape;
|
|
||||||
private final int actionSpaceSize;
|
private final int actionSpaceSize;
|
||||||
private final double gamma;
|
private final double gamma;
|
||||||
|
|
||||||
public QLearningUpdateAlgorithm(int[] shape,
|
public QLearningUpdateAlgorithm(int actionSpaceSize,
|
||||||
int actionSpaceSize,
|
|
||||||
double gamma) {
|
double gamma) {
|
||||||
|
|
||||||
this.shape = shape;
|
|
||||||
this.actionSpaceSize = actionSpaceSize;
|
this.actionSpaceSize = actionSpaceSize;
|
||||||
this.gamma = gamma;
|
this.gamma = gamma;
|
||||||
}
|
}
|
||||||
|
@ -44,33 +41,34 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
||||||
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
|
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
|
||||||
int size = experience.size();
|
int size = experience.size();
|
||||||
|
|
||||||
int[] nshape = Learning.makeShape(size, shape);
|
|
||||||
INDArray input = Nd4j.create(nshape);
|
|
||||||
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
|
||||||
|
|
||||||
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||||
|
|
||||||
|
INDArray data = stateActionPair.getObservation().getData();
|
||||||
|
INDArray features = INDArrayHelper.createBatchForShape(size, data.shape());
|
||||||
|
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
||||||
|
|
||||||
double r;
|
double r;
|
||||||
if (stateActionPair.isTerminal()) {
|
if (stateActionPair.isTerminal()) {
|
||||||
r = 0;
|
r = 0;
|
||||||
} else {
|
} else {
|
||||||
INDArray[] output = null;
|
INDArray[] output = null;
|
||||||
output = current.outputAll(stateActionPair.getObservation().getData());
|
output = current.outputAll(data);
|
||||||
r = Nd4j.max(output[0]).getDouble(0);
|
r = Nd4j.max(output[0]).getDouble(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = size - 1; i >= 0; i--) {
|
for (int i = size - 1; i >= 0; i--) {
|
||||||
stateActionPair = experience.get(i);
|
stateActionPair = experience.get(i);
|
||||||
|
data = stateActionPair.getObservation().getData();
|
||||||
|
|
||||||
input.putRow(i, stateActionPair.getObservation().getData());
|
features.putRow(i, data);
|
||||||
|
|
||||||
r = stateActionPair.getReward() + gamma * r;
|
r = stateActionPair.getReward() + gamma * r;
|
||||||
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
|
INDArray[] output = current.outputAll(data);
|
||||||
INDArray row = output[0];
|
INDArray row = output[0];
|
||||||
row = row.putScalar(stateActionPair.getAction(), r);
|
row = row.putScalar(stateActionPair.getAction(), r);
|
||||||
targets.putRow(i, row);
|
targets.putRow(i, row);
|
||||||
}
|
}
|
||||||
|
|
||||||
return current.gradient(input, targets);
|
return current.gradient(features, targets);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,6 +80,11 @@ public class ExpReplay<A> implements IExpReplay<A> {
|
||||||
//log.info("size: "+storage.size());
|
//log.info("size: "+storage.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getDesignatedBatchSize() {
|
||||||
|
return batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
public int getBatchSize() {
|
public int getBatchSize() {
|
||||||
int storageSize = storage.size();
|
int storageSize = storage.size();
|
||||||
return Math.min(storageSize, batchSize);
|
return Math.min(storageSize, batchSize);
|
||||||
|
|
|
@ -47,4 +47,9 @@ public interface IExpReplay<A> {
|
||||||
* @param transition a new transition to store
|
* @param transition a new transition to store
|
||||||
*/
|
*/
|
||||||
void store(Transition<A> transition);
|
void store(Transition<A> transition);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The desired size of batches
|
||||||
|
*/
|
||||||
|
int getDesignatedBatchSize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,25 +51,16 @@ import java.util.List;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||||
extends SyncLearning<O, A, AS, IDQN>
|
extends SyncLearning<O, A, AS, IDQN>
|
||||||
implements TargetQNetworkSource, IEpochTrainer {
|
implements IEpochTrainer {
|
||||||
|
|
||||||
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||||
|
|
||||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
protected abstract EpsGreedy<A> getEgPolicy();
|
||||||
|
|
||||||
public abstract MDP<O, A, AS> getMdp();
|
public abstract MDP<O, A, AS> getMdp();
|
||||||
|
|
||||||
public abstract IDQN getQNetwork();
|
public abstract IDQN getQNetwork();
|
||||||
|
|
||||||
public abstract IDQN getTargetQNetwork();
|
|
||||||
|
|
||||||
protected abstract void setTargetQNetwork(IDQN dqn);
|
|
||||||
|
|
||||||
protected void updateTargetNetwork() {
|
|
||||||
log.info("Update target network");
|
|
||||||
setTargetQNetwork(getQNetwork().clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
public IDQN getNeuralNet() {
|
public IDQN getNeuralNet() {
|
||||||
return getQNetwork();
|
return getQNetwork();
|
||||||
}
|
}
|
||||||
|
@ -101,11 +92,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
int numQ = 0;
|
int numQ = 0;
|
||||||
List<Double> scores = new ArrayList<>();
|
List<Double> scores = new ArrayList<>();
|
||||||
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
||||||
|
|
||||||
if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
|
|
||||||
updateTargetNetwork();
|
|
||||||
}
|
|
||||||
|
|
||||||
QLStepReturn<Observation> stepR = trainStep(obs);
|
QLStepReturn<Observation> stepR = trainStep(obs);
|
||||||
|
|
||||||
if (!stepR.getMaxQ().isNaN()) {
|
if (!stepR.getMaxQ().isNaN()) {
|
||||||
|
@ -146,7 +132,6 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
protected void resetNetworks() {
|
protected void resetNetworks() {
|
||||||
getQNetwork().reset();
|
getQNetwork().reset();
|
||||||
getTargetQNetwork().reset();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private InitMdp<Observation> refacInitMdp() {
|
private InitMdp<Observation> refacInitMdp() {
|
||||||
|
|
|
@ -21,6 +21,10 @@ import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||||
|
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.DQNNeuralNetUpdateRule;
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
|
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
@ -28,9 +32,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
|
@ -41,12 +42,8 @@ import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||||
|
@ -63,22 +60,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
@Getter
|
@Getter
|
||||||
private DQNPolicy<O> policy;
|
private DQNPolicy<O> policy;
|
||||||
@Getter
|
@Getter
|
||||||
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
|
private EpsGreedy<Integer> egPolicy;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final private IDQN qNetwork;
|
final private IDQN qNetwork;
|
||||||
@Getter
|
|
||||||
@Setter(AccessLevel.PROTECTED)
|
|
||||||
private IDQN targetQNetwork;
|
|
||||||
|
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
|
|
||||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
private final ILearningBehavior<Integer> learningBehavior;
|
||||||
|
|
||||||
// TODO: User a builder and remove the setter
|
|
||||||
@Getter(AccessLevel.PROTECTED) @Setter
|
|
||||||
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
|
|
||||||
|
|
||||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||||
return mdp;
|
return mdp;
|
||||||
|
@ -88,21 +78,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep, Random random) {
|
||||||
|
this(mdp, dqn, conf, epsilonNbStep, buildLearningBehavior(dqn, conf, random), random);
|
||||||
|
}
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
||||||
int epsilonNbStep, Random random) {
|
int epsilonNbStep, ILearningBehavior<Integer> learningBehavior, Random random) {
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = new LegacyMDPWrapper<>(mdp, null);
|
this.mdp = new LegacyMDPWrapper<>(mdp, null);
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
targetQNetwork = dqn.clone();
|
|
||||||
policy = new DQNPolicy(getQNetwork());
|
policy = new DQNPolicy(getQNetwork());
|
||||||
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, random, conf.getMinEpsilon(),
|
||||||
this);
|
this);
|
||||||
|
|
||||||
tdTargetAlgorithm = conf.isDoubleDQN()
|
this.learningBehavior = learningBehavior;
|
||||||
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
}
|
||||||
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
|
|
||||||
|
private static ILearningBehavior<Integer> buildLearningBehavior(IDQN qNetwork, QLearningConfiguration conf, Random random) {
|
||||||
|
IUpdateRule<Transition<Integer>> updateRule = new DQNNeuralNetUpdateRule(qNetwork, conf.getTargetDqnUpdateFreq(), conf.isDoubleDQN(), conf.getGamma(), conf.getErrorClamp());
|
||||||
|
ExperienceHandler<Integer, Transition<Integer>> experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||||
|
return LearningBehavior.<Integer, Transition<Integer>>builder()
|
||||||
|
.experienceHandler(experienceHandler)
|
||||||
|
.updateRule(updateRule)
|
||||||
|
.experienceUpdateSize(conf.getBatchSize())
|
||||||
|
.build();
|
||||||
|
|
||||||
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
||||||
|
@ -119,7 +119,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
public void preEpoch() {
|
public void preEpoch() {
|
||||||
lastAction = mdp.getActionSpace().noOp();
|
lastAction = mdp.getActionSpace().noOp();
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
experienceHandler.reset();
|
learningBehavior.handleEpisodeStart();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -136,12 +136,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
*/
|
*/
|
||||||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||||
|
|
||||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
|
||||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
|
||||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
|
||||||
int updateStart = this.getConfiguration().getUpdateStart()
|
|
||||||
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
|
|
||||||
|
|
||||||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||||
|
|
||||||
//if step of training, just repeat lastAction
|
//if step of training, just repeat lastAction
|
||||||
|
@ -160,29 +154,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
if (!obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
|
|
||||||
// Add experience
|
// Add experience
|
||||||
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
|
learningBehavior.handleNewExperience(obs, lastAction, accuReward, stepReply.isDone());
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
|
|
||||||
// Update NN
|
|
||||||
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
|
||||||
if (this.getStepCount() > updateStart) {
|
|
||||||
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
|
|
||||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
|
||||||
if (transitions.size() == 0)
|
|
||||||
throw new IllegalArgumentException("too few transitions");
|
|
||||||
|
|
||||||
return tdTargetAlgorithm.computeTDTargets(transitions);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finishEpoch(Observation observation) {
|
protected void finishEpoch(Observation observation) {
|
||||||
experienceHandler.setFinalObservation(observation);
|
learningBehavior.handleEpisodeEnd(observation);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,21 +2,19 @@ package org.deeplearning4j.rl4j.mdp;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
import org.deeplearning4j.rl4j.environment.*;
|
||||||
import org.deeplearning4j.rl4j.environment.Environment;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.deeplearning4j.rl4j.environment.Schema;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
public class CartpoleEnvironment implements Environment<Integer> {
|
public class CartpoleEnvironment implements Environment<Integer> {
|
||||||
private static final int NUM_ACTIONS = 2;
|
private static final int NUM_ACTIONS = 2;
|
||||||
private static final int ACTION_LEFT = 0;
|
private static final int ACTION_LEFT = 0;
|
||||||
private static final int ACTION_RIGHT = 1;
|
private static final int ACTION_RIGHT = 1;
|
||||||
|
|
||||||
private static final Schema<Integer> schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
|
private final Schema<Integer> schema;
|
||||||
|
|
||||||
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
|
||||||
|
|
||||||
|
@ -48,11 +46,12 @@ public class CartpoleEnvironment implements Environment<Integer> {
|
||||||
private Integer stepsBeyondDone;
|
private Integer stepsBeyondDone;
|
||||||
|
|
||||||
public CartpoleEnvironment() {
|
public CartpoleEnvironment() {
|
||||||
rnd = new Random();
|
this(Nd4j.getRandom());
|
||||||
}
|
}
|
||||||
|
|
||||||
public CartpoleEnvironment(int seed) {
|
public CartpoleEnvironment(Random rnd) {
|
||||||
rnd = new Random(seed);
|
this.rnd = rnd;
|
||||||
|
this.schema = new Schema<Integer>(new IntegerActionSchema(NUM_ACTIONS, ACTION_LEFT, rnd));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,16 +17,19 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.policy;
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.Builder;
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.rl4j.environment.IActionSchema;
|
||||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/24/16.
|
||||||
|
@ -38,18 +41,60 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
* epislon is annealed to minEpsilon over epsilonNbStep steps
|
* epislon is annealed to minEpsilon over epsilonNbStep steps
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@AllArgsConstructor
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> extends Policy<A> {
|
public class EpsGreedy<A> extends Policy<A> {
|
||||||
|
|
||||||
final private Policy<A> policy;
|
final private INeuralNetPolicy<A> policy;
|
||||||
final private MDP<OBSERVATION, A, AS> mdp;
|
|
||||||
final private int updateStart;
|
final private int updateStart;
|
||||||
final private int epsilonNbStep;
|
final private int epsilonNbStep;
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
final private double minEpsilon;
|
final private double minEpsilon;
|
||||||
|
|
||||||
|
private final IActionSchema<A> actionSchema;
|
||||||
|
|
||||||
|
final private MDP<Encodable, A, ActionSpace<A>> mdp;
|
||||||
final private IEpochTrainer learning;
|
final private IEpochTrainer learning;
|
||||||
|
|
||||||
|
// Using agent's (learning's) step count is incorrect; frame skipping makes epsilon's value decrease too quickly
|
||||||
|
private int annealingStep = 0;
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public <OBSERVATION extends Encodable, AS extends ActionSpace<A>> EpsGreedy(Policy<A> policy,
|
||||||
|
MDP<Encodable, A, ActionSpace<A>> mdp,
|
||||||
|
int updateStart,
|
||||||
|
int epsilonNbStep,
|
||||||
|
Random rnd,
|
||||||
|
double minEpsilon,
|
||||||
|
IEpochTrainer learning) {
|
||||||
|
this.policy = policy;
|
||||||
|
this.mdp = mdp;
|
||||||
|
this.updateStart = updateStart;
|
||||||
|
this.epsilonNbStep = epsilonNbStep;
|
||||||
|
this.rnd = rnd;
|
||||||
|
this.minEpsilon = minEpsilon;
|
||||||
|
this.learning = learning;
|
||||||
|
|
||||||
|
this.actionSchema = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public EpsGreedy(@NonNull Policy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep) {
|
||||||
|
this(policy, actionSchema, minEpsilon, updateStart, epsilonNbStep, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public EpsGreedy(@NonNull INeuralNetPolicy<A> policy, @NonNull IActionSchema<A> actionSchema, double minEpsilon, int updateStart, int epsilonNbStep, Random rnd) {
|
||||||
|
this.policy = policy;
|
||||||
|
|
||||||
|
this.rnd = rnd == null ? Nd4j.getRandom() : rnd;
|
||||||
|
this.minEpsilon = minEpsilon;
|
||||||
|
this.updateStart = updateStart;
|
||||||
|
this.epsilonNbStep = epsilonNbStep;
|
||||||
|
this.actionSchema = actionSchema;
|
||||||
|
|
||||||
|
this.mdp = null;
|
||||||
|
this.learning = null;
|
||||||
|
}
|
||||||
|
|
||||||
public NeuralNet getNeuralNet() {
|
public NeuralNet getNeuralNet() {
|
||||||
return policy.getNeuralNet();
|
return policy.getNeuralNet();
|
||||||
}
|
}
|
||||||
|
@ -57,6 +102,11 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
|
||||||
public A nextAction(INDArray input) {
|
public A nextAction(INDArray input) {
|
||||||
|
|
||||||
double ep = getEpsilon();
|
double ep = getEpsilon();
|
||||||
|
if(actionSchema != null) {
|
||||||
|
// Only legacy classes should pass here.
|
||||||
|
throw new RuntimeException("nextAction(Observation observation) should be called when using a AgentLearner");
|
||||||
|
}
|
||||||
|
|
||||||
if (learning.getStepCount() % 500 == 1)
|
if (learning.getStepCount() % 500 == 1)
|
||||||
log.info("EP: " + ep + " " + learning.getStepCount());
|
log.info("EP: " + ep + " " + learning.getStepCount());
|
||||||
if (rnd.nextDouble() > ep)
|
if (rnd.nextDouble() > ep)
|
||||||
|
@ -66,10 +116,31 @@ public class EpsGreedy<OBSERVATION extends Encodable, A, AS extends ActionSpace<
|
||||||
}
|
}
|
||||||
|
|
||||||
public A nextAction(Observation observation) {
|
public A nextAction(Observation observation) {
|
||||||
|
if(actionSchema == null) {
|
||||||
return this.nextAction(observation.getData());
|
return this.nextAction(observation.getData());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
A result;
|
||||||
|
|
||||||
|
double ep = getEpsilon();
|
||||||
|
if (annealingStep % 500 == 1) {
|
||||||
|
log.info("EP: " + ep + " " + annealingStep);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rnd.nextDouble() > ep) {
|
||||||
|
result = policy.nextAction(observation);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
result = actionSchema.getRandomAction();
|
||||||
|
}
|
||||||
|
|
||||||
|
++annealingStep;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
public double getEpsilon() {
|
public double getEpsilon() {
|
||||||
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
|
int step = actionSchema != null ? annealingStep : learning.getStepCount();
|
||||||
|
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (step - updateStart) * 1.0 / epsilonNbStep));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
package org.deeplearning4j.rl4j.policy;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
|
||||||
|
public interface INeuralNetPolicy<ACTION> extends IPolicy<ACTION> {
|
||||||
|
NeuralNet getNeuralNet();
|
||||||
|
}
|
|
@ -34,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
*
|
*
|
||||||
* A Policy responsability is to choose the next action given a state
|
* A Policy responsability is to choose the next action given a state
|
||||||
*/
|
*/
|
||||||
public abstract class Policy<A> implements IPolicy<A> {
|
public abstract class Policy<A> implements INeuralNetPolicy<A> {
|
||||||
|
|
||||||
public abstract NeuralNet getNeuralNet();
|
public abstract NeuralNet getNeuralNet();
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,211 @@
|
||||||
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.agent.learning.LearningBehavior;
|
||||||
|
import org.deeplearning4j.rl4j.environment.Environment;
|
||||||
|
import org.deeplearning4j.rl4j.environment.IntegerActionSchema;
|
||||||
|
import org.deeplearning4j.rl4j.environment.Schema;
|
||||||
|
import org.deeplearning4j.rl4j.environment.StepResult;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||||
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.invocation.InvocationOnMock;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
import org.mockito.stubbing.Answer;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.*;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
public class AgentLearnerTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
Environment<Integer> environmentMock;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
TransformProcess transformProcessMock;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
IPolicy<Integer> policyMock;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
LearningBehavior<Integer, Object> learningBehaviorMock;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_episodeIsStarted_expect_learningBehaviorHandleEpisodeStartCalled() {
|
||||||
|
// Arrange
|
||||||
|
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||||
|
.maxEpisodeSteps(3)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||||
|
when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
|
||||||
|
|
||||||
|
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
|
||||||
|
|
||||||
|
when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
verify(learningBehaviorMock, times(1)).handleEpisodeStart();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_runIsCalled_expect_experienceHandledWithLearningBehavior() {
|
||||||
|
// Arrange
|
||||||
|
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||||
|
.maxEpisodeSteps(4)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
|
|
||||||
|
double[] reward = new double[] { 0.0 };
|
||||||
|
when(environmentMock.step(any(Integer.class)))
|
||||||
|
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||||
|
|
||||||
|
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||||
|
|
||||||
|
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||||
|
.thenAnswer(new Answer<Observation>() {
|
||||||
|
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||||
|
int step = (int)invocation.getArgument(1);
|
||||||
|
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||||
|
return (step % 2 == 0 || isTerminal)
|
||||||
|
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||||
|
: Observation.SkippedObservation;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||||
|
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
|
||||||
|
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||||
|
ArgumentCaptor<Boolean> isTerminalCaptor = ArgumentCaptor.forClass(Boolean.class);
|
||||||
|
|
||||||
|
verify(learningBehaviorMock, times(2)).handleNewExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminalCaptor.capture());
|
||||||
|
List<Observation> observations = observationCaptor.getAllValues();
|
||||||
|
List<Integer> actions = actionCaptor.getAllValues();
|
||||||
|
List<Double> rewards = rewardCaptor.getAllValues();
|
||||||
|
List<Boolean> isTerminalList = isTerminalCaptor.getAllValues();
|
||||||
|
|
||||||
|
assertEquals(0.0, observations.get(0).getData().getDouble(0), 0.00001);
|
||||||
|
assertEquals(0, (int)actions.get(0));
|
||||||
|
assertEquals(0.0 + 1.0, rewards.get(0), 0.00001);
|
||||||
|
assertFalse(isTerminalList.get(0));
|
||||||
|
|
||||||
|
assertEquals(2.2, observations.get(1).getData().getDouble(0), 0.00001);
|
||||||
|
assertEquals(2, (int)actions.get(1));
|
||||||
|
assertEquals(2.0 + 3.0, rewards.get(1), 0.00001);
|
||||||
|
assertFalse(isTerminalList.get(1));
|
||||||
|
|
||||||
|
ArgumentCaptor<Observation> finalObservationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||||
|
verify(learningBehaviorMock, times(1)).handleEpisodeEnd(finalObservationCaptor.capture());
|
||||||
|
assertEquals(4.4, finalObservationCaptor.getValue().getData().getDouble(0), 0.00001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_runIsCalledMultipleTimes_expect_totalStepCountCorrect() {
|
||||||
|
// Arrange
|
||||||
|
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||||
|
.maxEpisodeSteps(4)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
|
|
||||||
|
double[] reward = new double[] { 0.0 };
|
||||||
|
when(environmentMock.step(any(Integer.class)))
|
||||||
|
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||||
|
|
||||||
|
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||||
|
|
||||||
|
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||||
|
.thenAnswer(new Answer<Observation>() {
|
||||||
|
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||||
|
int step = (int)invocation.getArgument(1);
|
||||||
|
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||||
|
return (step % 2 == 0 || isTerminal)
|
||||||
|
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||||
|
: Observation.SkippedObservation;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
reward[0] = 0.0;
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(8, sut.getTotalStepCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_runIsCalledMultipleTimes_expect_rewardSentToLearningBehaviorToBeCorrect() {
|
||||||
|
// Arrange
|
||||||
|
AgentLearner<Integer> sut = AgentLearner.builder(environmentMock, transformProcessMock, policyMock, learningBehaviorMock)
|
||||||
|
.maxEpisodeSteps(4)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
|
|
||||||
|
double[] reward = new double[] { 0.0 };
|
||||||
|
when(environmentMock.step(any(Integer.class)))
|
||||||
|
.thenAnswer(a -> new StepResult(new HashMap<>(), ++reward[0], reward[0] == 4.0));
|
||||||
|
|
||||||
|
when(environmentMock.isEpisodeFinished()).thenAnswer(x -> reward[0] == 4.0);
|
||||||
|
|
||||||
|
when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
|
||||||
|
.thenAnswer(new Answer<Observation>() {
|
||||||
|
public Observation answer(InvocationOnMock invocation) throws Throwable {
|
||||||
|
int step = (int)invocation.getArgument(1);
|
||||||
|
boolean isTerminal = (boolean)invocation.getArgument(2);
|
||||||
|
return (step % 2 == 0 || isTerminal)
|
||||||
|
? new Observation(Nd4j.create(new double[] { step * 1.1 }))
|
||||||
|
: Observation.SkippedObservation;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
when(policyMock.nextAction(any(Observation.class))).thenAnswer(x -> (int)reward[0]);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.run();
|
||||||
|
reward[0] = 0.0;
|
||||||
|
sut.run();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||||
|
|
||||||
|
verify(learningBehaviorMock, times(4)).handleNewExperience(any(Observation.class), any(Integer.class), rewardCaptor.capture(), any(Boolean.class));
|
||||||
|
List<Double> rewards = rewardCaptor.getAllValues();
|
||||||
|
|
||||||
|
// rewardAtLastExperience at the end of 1st call to .run() should not leak into 2nd call.
|
||||||
|
assertEquals(0.0 + 1.0, rewards.get(2), 0.00001);
|
||||||
|
assertEquals(2.0 + 3.0, rewards.get(3), 0.00001);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,10 +1,7 @@
|
||||||
package org.deeplearning4j.rl4j.agent;
|
package org.deeplearning4j.rl4j.agent;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
|
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
|
||||||
import org.deeplearning4j.rl4j.environment.ActionSchema;
|
import org.deeplearning4j.rl4j.environment.*;
|
||||||
import org.deeplearning4j.rl4j.environment.Environment;
|
|
||||||
import org.deeplearning4j.rl4j.environment.Schema;
|
|
||||||
import org.deeplearning4j.rl4j.environment.StepResult;
|
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
@ -12,6 +9,7 @@ import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
import org.mockito.*;
|
import org.mockito.*;
|
||||||
import org.mockito.junit.*;
|
import org.mockito.junit.*;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -23,8 +21,8 @@ import java.util.Map;
|
||||||
import static org.mockito.ArgumentMatchers.*;
|
import static org.mockito.ArgumentMatchers.*;
|
||||||
import static org.mockito.Mockito.*;
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
public class AgentTest {
|
public class AgentTest {
|
||||||
|
|
||||||
@Mock Environment environmentMock;
|
@Mock Environment environmentMock;
|
||||||
@Mock TransformProcess transformProcessMock;
|
@Mock TransformProcess transformProcessMock;
|
||||||
@Mock IPolicy policyMock;
|
@Mock IPolicy policyMock;
|
||||||
|
@ -102,7 +100,7 @@ public class AgentTest {
|
||||||
public void when_runIsCalled_expect_agentIsReset() {
|
public void when_runIsCalled_expect_agentIsReset() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -119,7 +117,7 @@ public class AgentTest {
|
||||||
sut.run();
|
sut.run();
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(0, sut.getEpisodeStepNumber());
|
assertEquals(0, sut.getEpisodeStepCount());
|
||||||
verify(transformProcessMock).transform(envResetResult, 0, false);
|
verify(transformProcessMock).transform(envResetResult, 0, false);
|
||||||
verify(policyMock, times(1)).reset();
|
verify(policyMock, times(1)).reset();
|
||||||
assertEquals(0.0, sut.getReward(), 0.00001);
|
assertEquals(0.0, sut.getReward(), 0.00001);
|
||||||
|
@ -130,7 +128,7 @@ public class AgentTest {
|
||||||
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
|
public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -152,7 +150,7 @@ public class AgentTest {
|
||||||
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -179,7 +177,7 @@ public class AgentTest {
|
||||||
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
|
public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -191,10 +189,10 @@ public class AgentTest {
|
||||||
final Agent spy = Mockito.spy(sut);
|
final Agent spy = Mockito.spy(sut);
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
((Agent)invocation.getMock()).incrementEpisodeStepCount();
|
||||||
return null;
|
return null;
|
||||||
}).when(spy).performStep();
|
}).when(spy).performStep();
|
||||||
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
|
when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepCount() >= 5 );
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
spy.run();
|
spy.run();
|
||||||
|
@ -209,7 +207,7 @@ public class AgentTest {
|
||||||
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
|
public void when_maxStepsIsReachedBeforeEposideEnds_expect_runTerminated() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -222,7 +220,7 @@ public class AgentTest {
|
||||||
final Agent spy = Mockito.spy(sut);
|
final Agent spy = Mockito.spy(sut);
|
||||||
|
|
||||||
doAnswer(invocation -> {
|
doAnswer(invocation -> {
|
||||||
((Agent)invocation.getMock()).incrementEpisodeStepNumber();
|
((Agent)invocation.getMock()).incrementEpisodeStepCount();
|
||||||
return null;
|
return null;
|
||||||
}).when(spy).performStep();
|
}).when(spy).performStep();
|
||||||
|
|
||||||
|
@ -239,7 +237,7 @@ public class AgentTest {
|
||||||
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
|
public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -264,7 +262,7 @@ public class AgentTest {
|
||||||
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
|
public void when_initialObservationsAreSkipped_expect_performNoOpActionAnd() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -289,7 +287,7 @@ public class AgentTest {
|
||||||
public void when_observationsIsSkipped_expect_performLastAction() {
|
public void when_observationsIsSkipped_expect_performLastAction() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Map<String, Object> envResetResult = new HashMap<>();
|
Map<String, Object> envResetResult = new HashMap<>();
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(envResetResult);
|
when(environmentMock.reset()).thenReturn(envResetResult);
|
||||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
|
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
@ -331,7 +329,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
|
|
||||||
|
@ -358,7 +356,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
|
public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
|
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
|
||||||
|
@ -381,7 +379,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
|
public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
|
when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
|
||||||
|
@ -405,7 +403,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
|
public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||||
|
@ -430,7 +428,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
|
public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||||
|
@ -458,7 +456,7 @@ public class AgentTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
|
public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
Schema schema = new Schema(new ActionSchema<>(-1));
|
Schema schema = new Schema(new IntegerActionSchema(0, -1));
|
||||||
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
when(environmentMock.reset()).thenReturn(new HashMap<>());
|
||||||
when(environmentMock.getSchema()).thenReturn(schema);
|
when(environmentMock.getSchema()).thenReturn(schema);
|
||||||
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
|
||||||
|
|
|
@ -0,0 +1,133 @@
|
||||||
|
package org.deeplearning4j.rl4j.agent.learning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.IUpdateRule;
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
public class LearningBehaviorTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
ExperienceHandler<Integer, Object> experienceHandlerMock;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
IUpdateRule<Object> updateRuleMock;
|
||||||
|
|
||||||
|
LearningBehavior<Integer, Object> sut;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
sut = LearningBehavior.<Integer, Object>builder()
|
||||||
|
.experienceHandler(experienceHandlerMock)
|
||||||
|
.updateRule(updateRuleMock)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingHandleEpisodeStart_expect_experienceHandlerResetCalled() {
|
||||||
|
// Arrange
|
||||||
|
LearningBehavior<Integer, Object> sut = LearningBehavior.<Integer, Object>builder()
|
||||||
|
.experienceHandler(experienceHandlerMock)
|
||||||
|
.updateRule(updateRuleMock)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.handleEpisodeStart();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
verify(experienceHandlerMock, times(1)).reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingHandleNewExperience_expect_experienceHandlerAddExperienceCalled() {
|
||||||
|
// Arrange
|
||||||
|
INDArray observationData = Nd4j.rand(1, 1);
|
||||||
|
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||||
|
ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
|
||||||
|
ArgumentCaptor<Double> rewardCaptor = ArgumentCaptor.forClass(Double.class);
|
||||||
|
ArgumentCaptor<Boolean> isTerminatedCaptor = ArgumentCaptor.forClass(Boolean.class);
|
||||||
|
verify(experienceHandlerMock, times(1)).addExperience(observationCaptor.capture(), actionCaptor.capture(), rewardCaptor.capture(), isTerminatedCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(1, (int)actionCaptor.getValue());
|
||||||
|
assertEquals(2.0, (double)rewardCaptor.getValue(), 0.00001);
|
||||||
|
assertFalse(isTerminatedCaptor.getValue());
|
||||||
|
|
||||||
|
verify(updateRuleMock, never()).update(any(List.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingHandleNewExperienceAndTrainingBatchIsReady_expect_updateRuleUpdateWithTrainingBatch() {
|
||||||
|
// Arrange
|
||||||
|
INDArray observationData = Nd4j.rand(1, 1);
|
||||||
|
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
|
||||||
|
List<Object> trainingBatch = new ArrayList<Object>();
|
||||||
|
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.handleNewExperience(new Observation(observationData), 1, 2.0, false);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
verify(updateRuleMock, times(1)).update(trainingBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingHandleEpisodeEnd_expect_experienceHandlerSetFinalObservationCalled() {
|
||||||
|
// Arrange
|
||||||
|
INDArray observationData = Nd4j.rand(1, 1);
|
||||||
|
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.handleEpisodeEnd(new Observation(observationData));
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||||
|
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||||
|
|
||||||
|
verify(updateRuleMock, never()).update(any(List.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingHandleEpisodeEndAndTrainingBatchIsNotEmpty_expect_updateRuleUpdateWithTrainingBatch() {
|
||||||
|
// Arrange
|
||||||
|
INDArray observationData = Nd4j.rand(1, 1);
|
||||||
|
when(experienceHandlerMock.isTrainingBatchReady()).thenReturn(true);
|
||||||
|
List<Object> trainingBatch = new ArrayList<Object>();
|
||||||
|
when(experienceHandlerMock.generateTrainingBatch()).thenReturn(trainingBatch);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.handleEpisodeEnd(new Observation(observationData));
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
|
||||||
|
verify(experienceHandlerMock, times(1)).setFinalObservation(observationCaptor.capture());
|
||||||
|
|
||||||
|
assertEquals(observationData.getDouble(0, 0), observationCaptor.getValue().getData().getDouble(0, 0), 0.00001);
|
||||||
|
|
||||||
|
verify(updateRuleMock, times(1)).update(trainingBatch);
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,34 +4,44 @@ import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.*;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
public class ReplayMemoryExperienceHandlerTest {
|
public class ReplayMemoryExperienceHandlerTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
IExpReplay<Integer> expReplayMock;
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
|
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
|
||||||
// Arrange
|
// Arrange
|
||||||
TestExpReplay expReplayMock = new TestExpReplay();
|
when(expReplayMock.getDesignatedBatchSize()).thenReturn(10);
|
||||||
|
|
||||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
int numStoredTransitions = expReplayMock.addedTransitions.size();
|
boolean isStoreCalledAfterFirstAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
|
||||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
boolean isStoreCalledAfterSecondAdd = mockingDetails(expReplayMock).getInvocations().stream().anyMatch(x -> x.getMethod().getName() == "store");
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(0, numStoredTransitions);
|
assertFalse(isStoreCalledAfterFirstAdd);
|
||||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
assertTrue(isStoreCalledAfterSecondAdd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_addingExperience_expect_transitionsAreCorrect() {
|
public void when_addingExperience_expect_transitionsAreCorrect() {
|
||||||
// Arrange
|
// Arrange
|
||||||
TestExpReplay expReplayMock = new TestExpReplay();
|
|
||||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -40,24 +50,25 @@ public class ReplayMemoryExperienceHandlerTest {
|
||||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(2, expReplayMock.addedTransitions.size());
|
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
|
||||||
|
verify(expReplayMock, times(2)).store(argument.capture());
|
||||||
|
List<Transition<Integer>> transitions = argument.getAllValues();
|
||||||
|
|
||||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
assertEquals(1.0, transitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
||||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
assertEquals(1, (int)transitions.get(0).getAction());
|
||||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
|
assertEquals(1.0, transitions.get(0).getReward(), 0.00001);
|
||||||
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
assertEquals(2.0, transitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
||||||
|
|
||||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
assertEquals(2.0, transitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
||||||
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
|
assertEquals(2, (int)transitions.get(1).getAction());
|
||||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
|
assertEquals(2.0, transitions.get(1).getReward(), 0.00001);
|
||||||
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
assertEquals(3.0, transitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
|
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
|
||||||
// Arrange
|
// Arrange
|
||||||
TestExpReplay expReplayMock = new TestExpReplay();
|
|
||||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -66,42 +77,57 @@ public class ReplayMemoryExperienceHandlerTest {
|
||||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
ArgumentCaptor<Transition<Integer>> argument = ArgumentCaptor.forClass(Transition.class);
|
||||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
verify(expReplayMock, times(1)).store(argument.capture());
|
||||||
|
Transition<Integer> transition = argument.getValue();
|
||||||
|
|
||||||
|
assertEquals(1, (int)transition.getAction());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||||
// Arrange
|
// Arrange
|
||||||
TestExpReplay expReplayMock = new TestExpReplay();
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
|
||||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
int size = sut.getTrainingBatchSize();
|
int size = sut.getTrainingBatchSize();
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(2, size);
|
assertEquals(2, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class TestExpReplay implements IExpReplay<Integer> {
|
@Test
|
||||||
|
public void when_experienceSizeIsSmallerThanBatchSize_expect_TrainingBatchIsNotReady() {
|
||||||
|
// Arrange
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||||
|
|
||||||
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
|
// Act
|
||||||
|
|
||||||
@Override
|
// Assert
|
||||||
public ArrayList<Transition<Integer>> getBatch() {
|
assertFalse(sut.isTrainingBatchReady());
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Test
|
||||||
public void store(Transition<Integer> transition) {
|
public void when_experienceSizeIsGreaterOrEqualToBatchSize_expect_TrainingBatchIsReady() {
|
||||||
addedTransitions.add(transition);
|
// Arrange
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(10, 5, Nd4j.getRandom());
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 4.0 })), 4, 4.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 5.0 })), 5, 5.0, false);
|
||||||
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 6.0 })));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(sut.isTrainingBatchReady());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getBatchSize() {
|
|
||||||
return addedTransitions.size();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ public class StateActionExperienceHandlerTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
|
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
|
||||||
// Arrange
|
// Arrange
|
||||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||||
sut.reset();
|
sut.reset();
|
||||||
Observation observation = new Observation(Nd4j.zeros(1));
|
Observation observation = new Observation(Nd4j.zeros(1));
|
||||||
sut.addExperience(observation, 123, 234.0, true);
|
sut.addExperience(observation, 123, 234.0, true);
|
||||||
|
@ -32,7 +32,7 @@ public class StateActionExperienceHandlerTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
|
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
|
||||||
// Arrange
|
// Arrange
|
||||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||||
sut.reset();
|
sut.reset();
|
||||||
sut.addExperience(null, 1, 1.0, false);
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
sut.addExperience(null, 2, 2.0, false);
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
@ -51,7 +51,7 @@ public class StateActionExperienceHandlerTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_gettingExperience_expect_experienceStoreIsCleared() {
|
public void when_gettingExperience_expect_experienceStoreIsCleared() {
|
||||||
// Arrange
|
// Arrange
|
||||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||||
sut.reset();
|
sut.reset();
|
||||||
sut.addExperience(null, 1, 1.0, false);
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ public class StateActionExperienceHandlerTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||||
// Arrange
|
// Arrange
|
||||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(Integer.MAX_VALUE);
|
||||||
sut.reset();
|
sut.reset();
|
||||||
sut.addExperience(null, 1, 1.0, false);
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
sut.addExperience(null, 2, 2.0, false);
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
@ -79,4 +79,66 @@ public class StateActionExperienceHandlerTest {
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(3, size);
|
assertEquals(3, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_experienceIsEmpty_expect_TrainingBatchNotReady() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||||
|
sut.reset();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(isTrainingBatchReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_experienceSizeIsGreaterOrEqualToThanBatchSize_expect_TrainingBatchIsReady() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||||
|
sut.reset();
|
||||||
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
sut.addExperience(null, 3, 3.0, false);
|
||||||
|
sut.addExperience(null, 4, 4.0, false);
|
||||||
|
sut.addExperience(null, 5, 5.0, false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(isTrainingBatchReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_experienceSizeIsSmallerThanBatchSizeButFinalObservationIsSet_expect_TrainingBatchIsReady() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||||
|
sut.reset();
|
||||||
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
sut.setFinalObservation(null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(isTrainingBatchReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_experienceSizeIsZeroAndFinalObservationIsSet_expect_TrainingBatchIsNotReady() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler(5);
|
||||||
|
sut.reset();
|
||||||
|
sut.setFinalObservation(null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isTrainingBatchReady = sut.isTrainingBatchReady();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(isTrainingBatchReady);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,4 +49,25 @@ public class INDArrayHelperTest {
|
||||||
assertEquals(1, output.shape()[1]);
|
assertEquals(1, output.shape()[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingCreateBatchForShape_expect_INDArrayWithCorrectShapeAndOriginalShapeUnchanged() {
|
||||||
|
// Arrange
|
||||||
|
long[] shape = new long[] { 1, 3, 4};
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray output = INDArrayHelper.createBatchForShape(2, shape);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// Output shape
|
||||||
|
assertEquals(3, output.shape().length);
|
||||||
|
assertEquals(2, output.shape()[0]);
|
||||||
|
assertEquals(3, output.shape()[1]);
|
||||||
|
assertEquals(4, output.shape()[2]);
|
||||||
|
|
||||||
|
// Input should remain unchanged
|
||||||
|
assertEquals(1, shape[0]);
|
||||||
|
assertEquals(3, shape[1]);
|
||||||
|
assertEquals(4, shape[2]);
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,10 +19,11 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.support.MockDQN;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.MockitoJUnitRunner;
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -32,6 +33,9 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
|
import static org.mockito.ArgumentMatchers.argThat;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
@RunWith(MockitoJUnitRunner.class)
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
public class QLearningUpdateAlgorithmTest {
|
public class QLearningUpdateAlgorithmTest {
|
||||||
|
@ -39,12 +43,24 @@ public class QLearningUpdateAlgorithmTest {
|
||||||
@Mock
|
@Mock
|
||||||
AsyncGlobal mockAsyncGlobal;
|
AsyncGlobal mockAsyncGlobal;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
IDQN dqnMock;
|
||||||
|
|
||||||
|
private UpdateAlgorithm sut;
|
||||||
|
|
||||||
|
private void setup(double gamma) {
|
||||||
|
// mock a neural net output -- just invert the sign of the input
|
||||||
|
when(dqnMock.outputAll(any(INDArray.class))).thenAnswer(invocation -> new INDArray[] { invocation.getArgument(0, INDArray.class).mul(-1.0) });
|
||||||
|
|
||||||
|
sut = new QLearningUpdateAlgorithm(2, gamma);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_isTerminal_expect_initRewardIs0() {
|
public void when_isTerminal_expect_initRewardIs0() {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockDQN dqnMock = new MockDQN();
|
setup(1.0);
|
||||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
|
|
||||||
final Observation observation = new Observation(Nd4j.zeros(1));
|
final Observation observation = new Observation(Nd4j.zeros(1, 2));
|
||||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new StateActionPair<Integer>(observation, 0, 0.0, true));
|
add(new StateActionPair<Integer>(observation, 0, 0.0, true));
|
||||||
|
@ -55,59 +71,68 @@ public class QLearningUpdateAlgorithmTest {
|
||||||
sut.computeGradients(dqnMock, experience);
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
|
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
|
||||||
// Arrange
|
// Arrange
|
||||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
|
setup(1.0);
|
||||||
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
|
|
||||||
|
final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }).reshape(1, 2));
|
||||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new StateActionPair<Integer>(observation, 0, 0.0, false));
|
add(new StateActionPair<Integer>(observation, 0, 0.0, false));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
MockDQN dqnMock = new MockDQN();
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.computeGradients(dqnMock, experience);
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertEquals(2, dqnMock.outputAllParams.size());
|
ArgumentCaptor<INDArray> argument = ArgumentCaptor.forClass(INDArray.class);
|
||||||
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
|
||||||
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
verify(dqnMock, times(2)).outputAll(argument.capture());
|
||||||
|
List<INDArray> values = argument.getAllValues();
|
||||||
|
assertEquals(-123.0, values.get(0).getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(-123.0, values.get(1).getDouble(0, 0), 0.00001);
|
||||||
|
|
||||||
|
verify(dqnMock, times(1)).gradient(any(INDArray.class), argThat((INDArray x) -> x.getDouble(0) == 234.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
|
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
|
||||||
// Arrange
|
// Arrange
|
||||||
double gamma = 0.9;
|
double gamma = 0.9;
|
||||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
|
setup(gamma);
|
||||||
|
|
||||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
{
|
{
|
||||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 }).reshape(1, 2)), 0, 1.0, false));
|
||||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 }).reshape(1, 2)), 1, 2.0, true));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
MockDQN dqnMock = new MockDQN();
|
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.computeGradients(dqnMock, experience);
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
|
ArgumentCaptor<INDArray> features = ArgumentCaptor.forClass(INDArray.class);
|
||||||
|
ArgumentCaptor<INDArray> targets = ArgumentCaptor.forClass(INDArray.class);
|
||||||
|
verify(dqnMock, times(1)).gradient(features.capture(), targets.capture());
|
||||||
|
|
||||||
// input side -- should be a stack of observations
|
// input side -- should be a stack of observations
|
||||||
INDArray input = dqnMock.gradientParams.get(0).getLeft();
|
INDArray featuresValues = features.getValue();
|
||||||
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
|
assertEquals(-1.1, featuresValues.getDouble(0, 0), 0.00001);
|
||||||
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
|
assertEquals(-1.2, featuresValues.getDouble(0, 1), 0.00001);
|
||||||
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
|
assertEquals(-2.1, featuresValues.getDouble(1, 0), 0.00001);
|
||||||
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
|
assertEquals(-2.2, featuresValues.getDouble(1, 1), 0.00001);
|
||||||
|
|
||||||
// target side
|
// target side
|
||||||
INDArray target = dqnMock.gradientParams.get(0).getRight();
|
INDArray targetsValues = targets.getValue();
|
||||||
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
|
assertEquals(1.0 + gamma * 2.0, targetsValues.getDouble(0, 0), 0.00001);
|
||||||
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
|
assertEquals(1.2, targetsValues.getDouble(0, 1), 0.00001);
|
||||||
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
|
assertEquals(2.1, targetsValues.getDouble(1, 0), 0.00001);
|
||||||
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
|
assertEquals(2.0, targetsValues.getDouble(1, 1), 0.00001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.agent.learning.ILearningBehavior;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
|
@ -74,6 +75,9 @@ public class QLearningDiscreteTest {
|
||||||
@Mock
|
@Mock
|
||||||
QLearningConfiguration mockQlearningConfiguration;
|
QLearningConfiguration mockQlearningConfiguration;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
ILearningBehavior<Integer> learningBehavior;
|
||||||
|
|
||||||
// HWC
|
// HWC
|
||||||
int[] observationShape = new int[]{3, 10, 10};
|
int[] observationShape = new int[]{3, 10, 10};
|
||||||
int totalObservationSize = 1;
|
int totalObservationSize = 1;
|
||||||
|
@ -92,12 +96,21 @@ public class QLearningDiscreteTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
|
private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay, ILearningBehavior<Integer> learningBehavior) {
|
||||||
when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
|
when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
|
||||||
when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
|
when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
|
||||||
when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
|
when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
|
||||||
when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
|
when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
|
||||||
|
|
||||||
|
if(learningBehavior != null) {
|
||||||
|
qLearningDiscrete = mock(
|
||||||
|
QLearningDiscrete.class,
|
||||||
|
Mockito.withSettings()
|
||||||
|
.useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0, learningBehavior, Nd4j.getRandom())
|
||||||
|
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else {
|
||||||
qLearningDiscrete = mock(
|
qLearningDiscrete = mock(
|
||||||
QLearningDiscrete.class,
|
QLearningDiscrete.class,
|
||||||
Mockito.withSettings()
|
Mockito.withSettings()
|
||||||
|
@ -105,6 +118,7 @@ public class QLearningDiscreteTest {
|
||||||
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
.defaultAnswer(Mockito.CALLS_REAL_METHODS)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void mockHistoryProcessor(int skipFrames) {
|
private void mockHistoryProcessor(int skipFrames) {
|
||||||
when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
|
when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
|
||||||
|
@ -136,7 +150,7 @@ public class QLearningDiscreteTest {
|
||||||
public void when_singleTrainStep_expect_correctValues() {
|
public void when_singleTrainStep_expect_correctValues() {
|
||||||
|
|
||||||
// Arrange
|
// Arrange
|
||||||
mockTestContext(100,0,2,1.0, 10);
|
mockTestContext(100,0,2,1.0, 10, null);
|
||||||
|
|
||||||
// An example observation and 2 Q values output (2 actions)
|
// An example observation and 2 Q values output (2 actions)
|
||||||
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
Observation observation = new Observation(Nd4j.zeros(observationShape));
|
||||||
|
@ -162,7 +176,7 @@ public class QLearningDiscreteTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_singleTrainStepSkippedFrames_expect_correctValues() {
|
public void when_singleTrainStepSkippedFrames_expect_correctValues() {
|
||||||
// Arrange
|
// Arrange
|
||||||
mockTestContext(100,0,2,1.0, 10);
|
mockTestContext(100,0,2,1.0, 10, learningBehavior);
|
||||||
|
|
||||||
Observation skippedObservation = Observation.SkippedObservation;
|
Observation skippedObservation = Observation.SkippedObservation;
|
||||||
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
|
Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
|
||||||
|
@ -180,8 +194,8 @@ public class QLearningDiscreteTest {
|
||||||
assertEquals(0, stepReply.getReward(), 1e-5);
|
assertEquals(0, stepReply.getReward(), 1e-5);
|
||||||
assertFalse(stepReply.isDone());
|
assertFalse(stepReply.isDone());
|
||||||
assertFalse(stepReply.getObservation().isSkipped());
|
assertFalse(stepReply.getObservation().isSkipped());
|
||||||
assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
|
|
||||||
|
|
||||||
|
verify(learningBehavior, never()).handleNewExperience(any(Observation.class), any(Integer.class), any(Double.class), any(Boolean.class));
|
||||||
verify(mockDQN, never()).output(any(INDArray.class));
|
verify(mockDQN, never()).output(any(INDArray.class));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue