RL4J: Add ExperienceHandler (#369)

* Added ExperienceHandler

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Added getTrainingBatchSize()

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2020-04-10 19:50:40 -04:00 committed by GitHub
parent 3e2dbc65dd
commit f1debe8c07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 1183 additions and 625 deletions

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.experience;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.List;
/**
* A common interface to all classes capable of handling experience generated by the agents in a learning context.
*
* @param <A> Action type
* @param <E> Experience type
*
* @author Alexandre Boulanger
*/
public interface ExperienceHandler<A, E> {
void addExperience(Observation observation, A action, double reward, boolean isTerminal);
/**
* Called when the episode is done with the last observation
* @param observation
*/
void setFinalObservation(Observation observation);
/**
* @return The size of the list that will be returned by generateTrainingBatch().
*/
int getTrainingBatchSize();
/**
* The elements are returned in the historical order (i.e. in the order they happened)
* @return The list of experience elements
*/
List<E> generateTrainingBatch();
/**
* Signal the experience handler that a new episode is starting
*/
void reset();
}

View File

@ -0,0 +1,111 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.experience;
import lombok.EqualsAndHashCode;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
/**
* A experience handler that stores the experience in a replay memory. See https://arxiv.org/abs/1312.5602
* The experience container is a {@link Transition Transition} that stores the tuple observation-action-reward-nextObservation,
* as well as whether or the not the episode ended after the Transition
*
* @param <A> Action type
*/
@EqualsAndHashCode
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
private static final int DEFAULT_BATCH_SIZE = 32;
private IExpReplay<A> expReplay;
private Transition<A> pendingTransition;
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
this.expReplay = expReplay;
}
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
this(new ExpReplay<A>(maxReplayMemorySize, batchSize, random));
}
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
setNextObservationOnPending(observation);
pendingTransition = new Transition<>(observation, action, reward, isTerminal);
}
public void setFinalObservation(Observation observation) {
setNextObservationOnPending(observation);
pendingTransition = null;
}
@Override
public int getTrainingBatchSize() {
return expReplay.getBatchSize();
}
/**
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
*/
@Override
public List<Transition<A>> generateTrainingBatch() {
return expReplay.getBatch();
}
@Override
public void reset() {
pendingTransition = null;
}
private void setNextObservationOnPending(Observation observation) {
if(pendingTransition != null) {
pendingTransition.setNextObservation(observation);
expReplay.store(pendingTransition);
}
}
public class Builder {
private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE;
private int batchSize = DEFAULT_BATCH_SIZE;
private Random random = Nd4j.getRandom();
public Builder maxReplayMemorySize(int value) {
maxReplayMemorySize = value;
return this;
}
public Builder batchSize(int value) {
batchSize = value;
return this;
}
public Builder random(Random value) {
random = value;
return this;
}
public ReplayMemoryExperienceHandler<A> build() {
return new ReplayMemoryExperienceHandler<A>(maxReplayMemorySize, batchSize, random);
}
}
}

View File

@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.experience;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
/**
* A simple {@link ExperienceHandler experience handler} that stores the experiences.
* Note: Calling {@link StateActionExperienceHandler#generateTrainingBatch() generateTrainingBatch()} will clear the stored experiences
*
* @param <A> Action type
*
* @author Alexandre Boulanger
*/
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
private List<StateActionPair<A>> stateActionPairs;
public void setFinalObservation(Observation observation) {
// Do nothing
}
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
stateActionPairs.add(new StateActionPair<A>(observation, action, reward, isTerminal));
}
@Override
public int getTrainingBatchSize() {
return stateActionPairs.size();
}
/**
* The elements are returned in the historical order (i.e. in the order they happened)
* Note: the experience store is cleared after calling this method.
*
* @return The list of experience elements
*/
@Override
public List<StateActionPair<A>> generateTrainingBatch() {
List<StateActionPair<A>> result = stateActionPairs;
stateActionPairs = new ArrayList<>();
return result;
}
@Override
public void reset() {
stateActionPairs = new ArrayList<>();
}
}

View File

@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.experience;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.observation.Observation;
/**
* A simple experience container. Used by {@link StateActionExperienceHandler StateActionExperienceHandler}.
*
* @param <A> Action type
*
* @author Alexandre Boulanger
*/
@AllArgsConstructor
public class StateActionPair<A> {
/**
* The observation before the action is taken
*/
@Getter
private final Observation observation;
@Getter
private final A action;
@Getter
private final double reward;
/**
* True if the episode ended after the action has been taken.
*/
@Getter
private final boolean terminal;
}

View File

@ -18,9 +18,12 @@
package org.deeplearning4j.rl4j.learning.async;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -28,10 +31,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -45,13 +44,39 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
@Getter
private NN current;
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
@Setter(AccessLevel.PROTECTED)
private UpdateAlgorithm<NN> updateAlgorithm;
// TODO: Make it configurable with a builder
@Setter(AccessLevel.PROTECTED)
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
MDP<O, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners,
int threadNumber,
int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
synchronized (asyncGlobal) {
current = (NN)asyncGlobal.getCurrent().clone();
}
}
// TODO: Add an actor-learner class and be able to inject the update algorithm
protected abstract UpdateAlgorithm<NN> buildUpdateAlgorithm();
@Override
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
super.setHistoryProcessor(historyProcessor);
updateAlgorithm = buildUpdateAlgorithm();
}
@Override
protected void preEpoch() {
experienceHandler.reset();
}
/**
* "Subepoch" correspond to the t_max-step iterations
* that stack rewards with t_max MiniTrans
@ -65,13 +90,11 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
synchronized (getAsyncGlobal()) {
current.copy(getAsyncGlobal().getCurrent());
}
Stack<MiniTrans<Integer>> rewards = new Stack<>();
Observation obs = sObs;
IPolicy<O, Integer> policy = getPolicy(current);
Integer action;
Integer lastAction = getMdp().getActionSpace().noOp();
Integer action = getMdp().getActionSpace().noOp();
IHistoryProcessor hp = getHistoryProcessor();
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
@ -82,21 +105,15 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
//if step of training, just repeat lastAction
if (obs.isSkipped()) {
action = lastAction;
} else {
if (!obs.isSkipped()) {
action = policy.nextAction(obs);
}
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
accuReward += stepReply.getReward() * getConf().getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (!obs.isSkipped()) {
INDArray[] output = current.outputAll(obs.getData());
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
accuReward = 0;
}
@ -104,29 +121,14 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
reward += stepReply.getReward();
incrementStep();
lastAction = action;
}
//a bit of a trick usable because of how the stack is treated to init R
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
if (getConf().getLearnerUpdateFrequency() == -1)
output = current.outputAll(obs.getData());
else synchronized (getAsyncGlobal()) {
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
}
double maxQ = Nd4j.max(output[0]).getDouble(0);
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
experienceHandler.setFinalObservation(obs);
}
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
}
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -13,28 +13,14 @@
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.NeuralNet;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* Its called a MiniTrans because it is similar to a Transition
* but without a next observation
*
* It is stacked and then processed by AsyncNStepQL or A3C
* following the paper implementation https://arxiv.org/abs/1602.01783 paper.
*
*/
@AllArgsConstructor
@Value
public class MiniTrans<A> {
INDArray obs;
A action;
INDArray[] output;
double reward;
import java.util.List;
public interface UpdateAlgorithm<NN extends NeuralNet> {
Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience);
}

View File

@ -18,11 +18,7 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.*;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -34,9 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Stack;
import org.nd4j.linalg.api.rng.Random;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
@ -67,6 +61,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
if(seed != null) {
rnd.setSeed(seed + threadNumber);
}
setUpdateAlgorithm(buildUpdateAlgorithm());
}
@Override
@ -74,52 +70,9 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
return new ACPolicy(net, rnd);
}
/**
* calc the gradients based on the n-step rewards
*/
@Override
public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) {
MiniTrans<Integer> minTrans = rewards.pop();
int size = rewards.size();
//if recurrent then train as a time serie with a batch size of 1
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
: Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) {
minTrans = rewards.pop();
r = minTrans.getReward() + conf.getGamma() * r;
if (recurrent) {
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(minTrans.getObs());
} else {
input.putRow(i, minTrans.getObs());
}
//the critic
targets.putScalar(i, r);
//the actor
double expectedV = minTrans.getOutput()[0].getDouble(0);
double advantage = r - expectedV;
if (recurrent) {
logSoftmax.putScalar(0, minTrans.getAction(), i, advantage);
} else {
logSoftmax.putScalar(i, minTrans.getAction(), advantage);
}
}
return iac.gradient(input, new INDArray[] {targets, logSoftmax});
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
}
}

View File

@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.List;
public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
private final IAsyncGlobal asyncGlobal;
private final int[] shape;
private final int actionSpaceSize;
private final int targetDqnUpdateFreq;
private final double gamma;
private final boolean recurrent;
public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
int[] shape,
int actionSpaceSize,
int targetDqnUpdateFreq,
double gamma) {
this.asyncGlobal = asyncGlobal;
//if recurrent then train as a time serie with a batch size of 1
recurrent = asyncGlobal.getCurrent().isRecurrent();
this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
this.gamma = gamma;
}
@Override
public Gradient[] computeGradients(IActorCritic current, List<StateActionPair<Integer>> experience) {
int size = experience.size();
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
: Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, actionSpaceSize, size)
: Nd4j.zeros(size, actionSpaceSize);
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
double r;
if(stateActionPair.isTerminal()) {
r = 0;
}
else {
INDArray[] output = null;
if (targetDqnUpdateFreq == -1)
output = current.outputAll(stateActionPair.getObservation().getData());
else synchronized (asyncGlobal) {
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
}
r = output[0].getDouble(0);
}
for (int i = size - 1; i >= 0; --i) {
stateActionPair = experience.get(i);
INDArray observationData = stateActionPair.getObservation().getData();
INDArray[] output = current.outputAll(observationData);
r = stateActionPair.getReward() + gamma * r;
if (recurrent) {
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
} else {
input.putRow(i, observationData);
}
//the critic
targets.putScalar(i, r);
//the actor
double expectedV = output[0].getDouble(0);
double advantage = r - expectedV;
if (recurrent) {
logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
} else {
logSoftmax.putScalar(i, stateActionPair.getAction(), advantage);
}
}
// targets -> value, critic
// logSoftmax -> policy, actor
return current.gradient(input, new INDArray[] {targets, logSoftmax});
}
}

View File

@ -18,11 +18,9 @@
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -32,12 +30,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.rng.Random;
import java.util.Stack;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
@ -65,6 +60,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
if(seed != null) {
rnd.setSeed(seed + threadNumber);
}
setUpdateAlgorithm(buildUpdateAlgorithm());
}
public Policy<O, Integer> getPolicy(IDQN nn) {
@ -72,32 +69,9 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
rnd, conf.getMinEpsilon(), this);
}
//calc the gradient based on the n-step rewards
public Gradient[] calcGradient(IDQN current, Stack<MiniTrans<Integer>> rewards) {
MiniTrans<Integer> minTrans = rewards.pop();
int size = rewards.size();
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
double r = minTrans.getReward();
for (int i = size - 1; i >= 0; i--) {
minTrans = rewards.pop();
r = minTrans.getReward() + conf.getGamma() * r;
input.putRow(i, minTrans.getObs());
INDArray row = minTrans.getOutput()[0];
row = row.putScalar(minTrans.getAction(), r);
targets.putRow(i, row);
}
return current.gradient(input, targets);
@Override
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
}
}

View File

@ -0,0 +1,88 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
private final IAsyncGlobal asyncGlobal;
private final int[] shape;
private final int actionSpaceSize;
private final int targetDqnUpdateFreq;
private final double gamma;
public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
int[] shape,
int actionSpaceSize,
int targetDqnUpdateFreq,
double gamma) {
this.asyncGlobal = asyncGlobal;
this.shape = shape;
this.actionSpaceSize = actionSpaceSize;
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
this.gamma = gamma;
}
@Override
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
int size = experience.size();
int[] nshape = Learning.makeShape(size, shape);
INDArray input = Nd4j.create(nshape);
INDArray targets = Nd4j.create(size, actionSpaceSize);
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
double r;
if(stateActionPair.isTerminal()) {
r = 0;
}
else {
INDArray[] output = null;
if (targetDqnUpdateFreq == -1)
output = current.outputAll(stateActionPair.getObservation().getData());
else synchronized (asyncGlobal) {
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
}
r = Nd4j.max(output[0]).getDouble(0);
}
for (int i = size - 1; i >= 0; i--) {
stateActionPair = experience.get(i);
input.putRow(i, stateActionPair.getObservation().getData());
r = stateActionPair.getReward() + gamma * r;
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
INDArray row = output[0];
row = row.putScalar(stateActionPair.getAction(), r);
targets.putRow(i, row);
}
return current.gradient(input, targets);
}
}

View File

@ -80,6 +80,9 @@ public class ExpReplay<A> implements IExpReplay<A> {
//log.info("size: "+storage.size());
}
public int getBatchSize() {
int storageSize = storage.size();
return Math.min(storageSize, batchSize);
}
}

View File

@ -32,6 +32,11 @@ import java.util.ArrayList;
*/
public interface IExpReplay<A> {
/**
* @return The size of the batch that will be returned by getBatch()
*/
int getBatchSize();
/**
* @return a batch of uniformly sampled transitions
*/
@ -42,5 +47,4 @@ public interface IExpReplay<A> {
* @param transition a new transition to store
*/
void store(Transition<A> transition);
}

View File

@ -60,32 +60,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
extends SyncLearning<O, A, AS, IDQN>
implements TargetQNetworkSource, EpochStepCounter {
// FIXME Changed for refac
// @Getter
// final private IExpReplay<A> expReplay;
@Getter
@Setter(AccessLevel.PROTECTED)
protected IExpReplay<A> expReplay;
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
public QLearning(QLearningConfiguration conf) {
this(conf, getSeededRandom(conf.getSeed()));
}
public QLearning(QLearningConfiguration conf, Random random) {
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
private static Random getSeededRandom(Long seed) {
Random rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed);
}
return rnd;
}
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
public abstract MDP<O, A, AS> getMdp();

View File

@ -21,6 +21,8 @@ import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
@ -42,7 +44,7 @@ import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
/**
@ -71,10 +73,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
private int lastAction;
private double accuReward = 0;
private Transition pendingTransition;
ITDTargetAlgorithm tdTargetAlgorithm;
// TODO: User a builder and remove the setter
@Getter(AccessLevel.PROTECTED) @Setter
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
return mdp;
}
@ -85,7 +89,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
qNetwork = dqn;
@ -98,6 +101,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
public MDP<O, Integer, DiscreteSpace> getMdp() {
@ -114,7 +118,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
public void preEpoch() {
lastAction = mdp.getActionSpace().noOp();
accuReward = 0;
pendingTransition = null;
experienceHandler.reset();
}
@Override
@ -131,8 +135,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
*/
protected QLStepReturn<Observation> trainStep(Observation obs) {
Integer action;
boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
@ -142,37 +144,28 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
Double maxQ = Double.NaN; //ignore if Nan for stats
//if step of training, just repeat lastAction
if (obs.isSkipped()) {
action = lastAction;
} else {
if (!obs.isSkipped()) {
INDArray qs = getQNetwork().output(obs);
int maxAction = Learning.getMaxAction(qs);
maxQ = qs.getDouble(maxAction);
action = getEgPolicy().nextAction(obs);
lastAction = getEgPolicy().nextAction(obs);
}
lastAction = action;
StepReply<Observation> stepReply = mdp.step(action);
StepReply<Observation> stepReply = mdp.step(lastAction);
accuReward += stepReply.getReward() * configuration.getRewardFactor();
//if it's not a skipped frame, you can do a step of training
if (!obs.isSkipped()) {
// Add experience
if (pendingTransition != null) {
pendingTransition.setNextObservation(obs);
getExpReplay().store(pendingTransition);
}
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
accuReward = 0;
// Update NN
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
if (getStepCounter() > updateStart) {
DataSet targets = setTarget(getExpReplay().getBatch());
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
}
@ -180,7 +173,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
}
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
protected DataSet setTarget(List<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
@ -189,9 +182,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Override
protected void finishEpoch(Observation observation) {
if (pendingTransition != null) {
pendingTransition.setNextObservation(observation);
getExpReplay().store(pendingTransition);
}
experienceHandler.setFinalObservation(observation);
}
}

View File

@ -0,0 +1,107 @@
package org.deeplearning4j.rl4j.experience;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class ReplayMemoryExperienceHandlerTest {
@Test
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
int numStoredTransitions = expReplayMock.addedTransitions.size();
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
// Assert
assertEquals(0, numStoredTransitions);
assertEquals(1, expReplayMock.addedTransitions.size());
}
@Test
public void when_addingExperience_expect_transitionsAreCorrect() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Assert
assertEquals(2, expReplayMock.addedTransitions.size());
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
}
@Test
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
// Act
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 })));
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
// Assert
assertEquals(1, expReplayMock.addedTransitions.size());
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
}
@Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange
TestExpReplay expReplayMock = new TestExpReplay();
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
// Act
int size = sut.getTrainingBatchSize();
// Assert
assertEquals(2, size);
}
private static class TestExpReplay implements IExpReplay<Integer> {
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
@Override
public ArrayList<Transition<Integer>> getBatch() {
return null;
}
@Override
public void store(Transition<Integer> transition) {
addedTransitions.add(transition);
}
@Override
public int getBatchSize() {
return addedTransitions.size();
}
}
}

View File

@ -0,0 +1,82 @@
package org.deeplearning4j.rl4j.experience;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import static org.junit.Assert.*;
public class StateActionExperienceHandlerTest {
@Test
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
sut.reset();
Observation observation = new Observation(Nd4j.zeros(1));
sut.addExperience(observation, 123, 234.0, true);
// Act
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
// Assert
assertEquals(1, result.size());
assertSame(observation, result.get(0).getObservation());
assertEquals(123, (int)result.get(0).getAction());
assertEquals(234.0, result.get(0).getReward(), 0.00001);
assertTrue(result.get(0).isTerminal());
}
@Test
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.addExperience(null, 3, 3.0, false);
// Act
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
// Assert
assertEquals(3, result.size());
assertEquals(1, (int)result.get(0).getAction());
assertEquals(2, (int)result.get(1).getAction());
assertEquals(3, (int)result.get(2).getAction());
}
@Test
public void when_gettingExperience_expect_experienceStoreIsCleared() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
sut.reset();
sut.addExperience(null, 1, 1.0, false);
// Act
List<StateActionPair<Integer>> firstResult = sut.generateTrainingBatch();
List<StateActionPair<Integer>> secondResult = sut.generateTrainingBatch();
// Assert
assertEquals(1, firstResult.size());
assertEquals(0, secondResult.size());
}
@Test
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
// Arrange
StateActionExperienceHandler sut = new StateActionExperienceHandler();
sut.reset();
sut.addExperience(null, 1, 1.0, false);
sut.addExperience(null, 2, 2.0, false);
sut.addExperience(null, 3, 3.0, false);
// Act
int size = sut.getTrainingBatchSize();
// Assert
assertEquals(3, size);
}
}

View File

@ -18,9 +18,12 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
@ -31,7 +34,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import static org.junit.Assert.assertEquals;
@ -51,7 +53,9 @@ public class AsyncThreadDiscreteTest {
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
// Act
@ -60,8 +64,8 @@ public class AsyncThreadDiscreteTest {
// Assert
assertEquals(2, sut.trainSubEpochResults.size());
double[][] expectedLastObservations = new double[][] {
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
};
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
for(int i = 0; i < 2; ++i) {
@ -102,62 +106,22 @@ public class AsyncThreadDiscreteTest {
}
}
// NeuralNetwork
assertEquals(2, nnMock.copyCallCount);
double[][] expectedNNInputs = new double[][] {
// ExperienceHandler
double[][] expectedExperienceHandlerInputs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
};
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
for(int i = 0; i < expectedNNInputs.length; ++i) {
double[] expectedRow = expectedNNInputs[i];
INDArray input = nnMock.outputAllInputs.get(i);
assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
double[] expectedRow = expectedExperienceHandlerInputs[i];
INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
}
}
int arrayIdx = 0;
double[][][] expectedMinitransObs = new double[][][] {
new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
},
new double[][] {
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
}
};
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 };
assertEquals(2, sut.rewards.size());
for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) {
Stack<MiniTrans<Integer>> miniTransStack = sut.rewards.get(rewardIdx);
for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) {
MiniTrans minitrans = miniTransStack.get(i);
// Observation
double[] expectedRow = expectedMinitransObs[rewardIdx][i];
INDArray realRewards = minitrans.getObs();
assertEquals(expectedRow.length, realRewards.shape()[1]);
for (int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001);
}
assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001);
assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001);
++arrayIdx;
}
}
}
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
@ -167,22 +131,19 @@ public class AsyncThreadDiscreteTest {
private final MockAsyncConfiguration config;
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
public final List<Stack<MiniTrans<Integer>>> rewards = new ArrayList<Stack<MiniTrans<Integer>>>();
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
MockAsyncConfiguration config, IHistoryProcessor hp) {
MockAsyncConfiguration config, IHistoryProcessor hp,
ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.asyncGlobal = asyncGlobal;
this.policy = policy;
this.config = config;
setHistoryProcessor(hp);
}
@Override
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
this.rewards.add(rewards);
return new Gradient[0];
setExperienceHandler(experienceHandler);
setUpdateAlgorithm(updateAlgorithm);
}
@Override
@ -200,6 +161,11 @@ public class AsyncThreadDiscreteTest {
return policy;
}
@Override
protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
return null;
}
@Override
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
asyncGlobal.increaseCurrentLoop();

View File

@ -1,197 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Stack;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
public class A3CThreadDiscreteTest {
@Test
public void refac_calcGradient() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
MockActorCritic actorCriticMock = new MockActorCritic();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
A3CThreadDiscrete sut = new A3CThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, 0, null, 0);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hpMock);
double[][] minitransObs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
};
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
for(int i = 0; i < 3; ++i) {
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
INDArray[] output = new INDArray[] {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(actorCriticMock, minitransList);
// Assert
assertEquals(1, actorCriticMock.gradientParams.size());
INDArray input = actorCriticMock.gradientParams.get(0).getFirst();
INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond();
assertEquals(minitransObs.length, input.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = minitransObs[i];
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
}
}
double latestReward = (gamma * 4.0) + 3.0;
double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward };
for(int i = 0; i < expectedLabels0.length; ++i) {
assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001);
}
double[][] expectedLabels1 = new double[][] {
new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 },
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
};
assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape());
for(int i = 0; i < expectedLabels1.length; ++i) {
double[] expectedRow = expectedLabels1[i];
assertEquals(expectedRow.length, labels[1].shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001);
}
}
}
public class MockActorCritic implements IActorCritic {
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
}
@Override
public IActorCritic clone() {
return this;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IActorCritic from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
return new Gradient[0];
}
@Override
public void applyGradient(Gradient[] gradient, int batchSize) {
}
@Override
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
}
@Override
public void save(String pathValue, String pathPolicy) throws IOException {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
}

View File

@ -0,0 +1,160 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class A3CUpdateAlgorithmTest {
@Test
public void refac_calcGradient_non_terminal() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
MockMDP mdpMock = new MockMDP(observationSpace);
MockActorCritic actorCriticMock = new MockActorCritic();
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
INDArray[] originalObservations = new INDArray[] {
Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
};
int[] actions = new int[] { 0, 1, 2, 1 };
double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
for(int i = 0; i < originalObservations.length; ++i) {
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
}
// Act
sut.computeGradients(actorCriticMock, experience);
// Assert
assertEquals(1, actorCriticMock.gradientParams.size());
// Inputs
INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
for(int i = 0; i < 4; ++i) {
for(int j = 0; j < 5; ++j) {
assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
}
}
INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
assertEquals(4, targets.shape()[0]);
assertEquals(1, targets.shape()[1]);
// FIXME: check targets values once fixed
assertEquals(4, logSoftmax.shape()[0]);
assertEquals(5, logSoftmax.shape()[1]);
// FIXME: check logSoftmax values once fixed
}
public class MockActorCritic implements IActorCritic {
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
}
@Override
public boolean isRecurrent() {
return false;
}
@Override
public void reset() {
}
@Override
public void fit(INDArray input, INDArray[] labels) {
}
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[] { batch.mul(-1.0) };
}
@Override
public IActorCritic clone() {
return this;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IActorCritic from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
return new Gradient[0];
}
@Override
public void applyGradient(Gradient[] gradient, int batchSize) {
}
@Override
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
}
@Override
public void save(String pathValue, String pathPolicy) throws IOException {
}
@Override
public double getLatestScore() {
return 0;
}
@Override
public void save(OutputStream os) throws IOException {
}
@Override
public void save(String filename) throws IOException {
}
}
}

View File

@ -1,98 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2020 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Stack;
import static org.junit.Assert.assertEquals;
public class AsyncNStepQLearningThreadDiscreteTest {
@Test
public void refac_calcGradient() {
// Arrange
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
MockDQN dqnMock = new MockDQN();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, null, 0, 0);
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hpMock);
double[][] minitransObs = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
};
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
for(int i = 0; i < 3; ++i) {
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
INDArray[] output = new INDArray[] {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(dqnMock, minitransList);
// Assert
assertEquals(1, dqnMock.gradientParams.size());
INDArray input = dqnMock.gradientParams.get(0).getFirst();
INDArray labels = dqnMock.gradientParams.get(0).getSecond();
assertEquals(minitransObs.length, input.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = minitransObs[i];
assertEquals(expectedRow.length, input.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
}
}
double latestReward = (gamma * 4.0) + 3.0;
double[][] expectedLabels = new double[][] {
new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
};
assertEquals(minitransObs.length, labels.shape()[0]);
for(int i = 0; i < minitransObs.length; ++i) {
double[] expectedRow = expectedLabels[i];
assertEquals(expectedRow.length, labels.shape()[1]);
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
}
}
}
}

View File

@ -0,0 +1,115 @@
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class QLearningUpdateAlgorithmTest {
@Test
public void when_isTerminal_expect_initRewardIs0() {
// Arrange
MockDQN dqnMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
}
};
// Act
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
}
@Test
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
// Arrange
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(2, dqnMock.outputAllParams.size());
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
}
@Test
public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
// Arrange
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
assertEquals(1, globalDQNMock.outputAllParams.size());
assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
}
@Test
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
// Arrange
double gamma = 0.9;
MockDQN globalDQNMock = new MockDQN();
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
{
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
}
};
MockDQN dqnMock = new MockDQN();
// Act
sut.computeGradients(dqnMock, experience);
// Assert
// input side -- should be a stack of observations
INDArray input = dqnMock.gradientParams.get(0).getLeft();
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
// target side
INDArray target = dqnMock.gradientParams.get(0).getRight();
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
}
}

View File

@ -17,6 +17,8 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
@ -75,8 +77,8 @@ public class QLearningDiscreteTest {
.build();
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
MockExperienceHandler experienceHandler = new MockExperienceHandler();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp);
@ -93,7 +95,6 @@ public class QLearningDiscreteTest {
for (int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount);
@ -133,30 +134,31 @@ public class QLearningDiscreteTest {
// MDP calls
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
// ExpReplay calls
double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0};
int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4};
double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0};
double[][] expectedTrObservations = new double[][]{
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
// ExperienceHandler calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
};
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
for (int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
for (int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
for(int i = 0; i < expectedTrRewards.length; ++i) {
StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
}
}
assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
// trainEpoch result
assertEquals(initStepCount + 16, result.getStepCounter());
@ -167,20 +169,16 @@ public class QLearningDiscreteTest {
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager));
setExpReplay(expReplay);
setExperienceHandler(experienceHandler);
}
@Override
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0}));
}
public void setExpReplay(IExpReplay<Integer> exp) {
this.expReplay = exp;
protected DataSet setTarget(List<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
}
@Override

View File

@ -19,6 +19,7 @@ public class MockDQN implements IDQN {
public final List<INDArray> outputParams = new ArrayList<>();
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
public final List<INDArray> outputAllParams = new ArrayList<>();
@Override
public NeuralNetwork[] getNeuralNetworks() {
@ -58,7 +59,8 @@ public class MockDQN implements IDQN {
@Override
public INDArray[] outputAll(INDArray batch) {
return new INDArray[0];
outputAllParams.add(batch);
return new INDArray[] { batch.mul(-1.0) };
}
@Override

View File

@ -1,22 +0,0 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import java.util.ArrayList;
import java.util.List;
public class MockExpReplay implements IExpReplay<Integer> {
public List<Transition<Integer>> transitions = new ArrayList<>();
@Override
public ArrayList<Transition<Integer>> getBatch() {
return null;
}
@Override
public void store(Transition<Integer> transition) {
transitions.add(transition);
}
}

View File

@ -0,0 +1,46 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.observation.Observation;
import java.util.ArrayList;
import java.util.List;
public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
public Observation finalObservation;
public boolean isGenerateTrainingBatchCalled;
public boolean isResetCalled;
@Override
public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
}
@Override
public void setFinalObservation(Observation observation) {
finalObservation = observation;
}
@Override
public List<Transition<Integer>> generateTrainingBatch() {
isGenerateTrainingBatchCalled = true;
return new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(null, 0, 0.0, false));
}
};
}
@Override
public void reset() {
isResetCalled = true;
}
@Override
public int getTrainingBatchSize() {
return 1;
}
}

View File

@ -5,6 +5,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
public class MockObservationSpace implements ObservationSpace {
private final int[] shape;
public MockObservationSpace() {
this(new int[] { 1 });
}
public MockObservationSpace(int[] shape) {
this.shape = shape;
}
@Override
public String getName() {
return null;
@ -12,7 +22,7 @@ public class MockObservationSpace implements ObservationSpace {
@Override
public int[] getShape() {
return new int[] { 1 };
return shape;
}
@Override

View File

@ -0,0 +1,19 @@
package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
import java.util.ArrayList;
import java.util.List;
public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
@Override
public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
experiences.add(experience);
return new Gradient[0];
}
}