RL4J: Add ExperienceHandler (#369)
* Added ExperienceHandler Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added getTrainingBatchSize() Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
3e2dbc65dd
commit
f1debe8c07
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A common interface to all classes capable of handling experience generated by the agents in a learning context.
|
||||
*
|
||||
* @param <A> Action type
|
||||
* @param <E> Experience type
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface ExperienceHandler<A, E> {
|
||||
void addExperience(Observation observation, A action, double reward, boolean isTerminal);
|
||||
|
||||
/**
|
||||
* Called when the episode is done with the last observation
|
||||
* @param observation
|
||||
*/
|
||||
void setFinalObservation(Observation observation);
|
||||
|
||||
/**
|
||||
* @return The size of the list that will be returned by generateTrainingBatch().
|
||||
*/
|
||||
int getTrainingBatchSize();
|
||||
|
||||
/**
|
||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||
* @return The list of experience elements
|
||||
*/
|
||||
List<E> generateTrainingBatch();
|
||||
|
||||
/**
|
||||
* Signal the experience handler that a new episode is starting
|
||||
*/
|
||||
void reset();
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import lombok.EqualsAndHashCode;
|
||||
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A experience handler that stores the experience in a replay memory. See https://arxiv.org/abs/1312.5602
|
||||
* The experience container is a {@link Transition Transition} that stores the tuple observation-action-reward-nextObservation,
|
||||
* as well as whether or the not the episode ended after the Transition
|
||||
*
|
||||
* @param <A> Action type
|
||||
*/
|
||||
@EqualsAndHashCode
|
||||
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
|
||||
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
|
||||
private static final int DEFAULT_BATCH_SIZE = 32;
|
||||
|
||||
private IExpReplay<A> expReplay;
|
||||
|
||||
private Transition<A> pendingTransition;
|
||||
|
||||
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
|
||||
this.expReplay = expReplay;
|
||||
}
|
||||
|
||||
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
|
||||
this(new ExpReplay<A>(maxReplayMemorySize, batchSize, random));
|
||||
}
|
||||
|
||||
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||
setNextObservationOnPending(observation);
|
||||
pendingTransition = new Transition<>(observation, action, reward, isTerminal);
|
||||
}
|
||||
|
||||
public void setFinalObservation(Observation observation) {
|
||||
setNextObservationOnPending(observation);
|
||||
pendingTransition = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTrainingBatchSize() {
|
||||
return expReplay.getBatchSize();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
|
||||
*/
|
||||
@Override
|
||||
public List<Transition<A>> generateTrainingBatch() {
|
||||
return expReplay.getBatch();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
pendingTransition = null;
|
||||
}
|
||||
|
||||
private void setNextObservationOnPending(Observation observation) {
|
||||
if(pendingTransition != null) {
|
||||
pendingTransition.setNextObservation(observation);
|
||||
expReplay.store(pendingTransition);
|
||||
}
|
||||
}
|
||||
|
||||
public class Builder {
|
||||
private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE;
|
||||
private int batchSize = DEFAULT_BATCH_SIZE;
|
||||
private Random random = Nd4j.getRandom();
|
||||
|
||||
public Builder maxReplayMemorySize(int value) {
|
||||
maxReplayMemorySize = value;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder batchSize(int value) {
|
||||
batchSize = value;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder random(Random value) {
|
||||
random = value;
|
||||
return this;
|
||||
}
|
||||
|
||||
public ReplayMemoryExperienceHandler<A> build() {
|
||||
return new ReplayMemoryExperienceHandler<A>(maxReplayMemorySize, batchSize, random);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A simple {@link ExperienceHandler experience handler} that stores the experiences.
|
||||
* Note: Calling {@link StateActionExperienceHandler#generateTrainingBatch() generateTrainingBatch()} will clear the stored experiences
|
||||
*
|
||||
* @param <A> Action type
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
|
||||
|
||||
private List<StateActionPair<A>> stateActionPairs;
|
||||
|
||||
public void setFinalObservation(Observation observation) {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||
stateActionPairs.add(new StateActionPair<A>(observation, action, reward, isTerminal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTrainingBatchSize() {
|
||||
return stateActionPairs.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||
* Note: the experience store is cleared after calling this method.
|
||||
*
|
||||
* @return The list of experience elements
|
||||
*/
|
||||
@Override
|
||||
public List<StateActionPair<A>> generateTrainingBatch() {
|
||||
List<StateActionPair<A>> result = stateActionPairs;
|
||||
stateActionPairs = new ArrayList<>();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
stateActionPairs = new ArrayList<>();
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
/**
|
||||
* A simple experience container. Used by {@link StateActionExperienceHandler StateActionExperienceHandler}.
|
||||
*
|
||||
* @param <A> Action type
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
public class StateActionPair<A> {
|
||||
|
||||
/**
|
||||
* The observation before the action is taken
|
||||
*/
|
||||
@Getter
|
||||
private final Observation observation;
|
||||
|
||||
@Getter
|
||||
private final A action;
|
||||
|
||||
@Getter
|
||||
private final double reward;
|
||||
|
||||
/**
|
||||
* True if the episode ended after the action has been taken.
|
||||
*/
|
||||
@Getter
|
||||
private final boolean terminal;
|
||||
}
|
|
@ -18,9 +18,12 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -28,10 +31,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
|||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Stack;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
|
@ -45,13 +44,39 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
@Getter
|
||||
private NN current;
|
||||
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||
@Setter(AccessLevel.PROTECTED)
|
||||
private UpdateAlgorithm<NN> updateAlgorithm;
|
||||
|
||||
// TODO: Make it configurable with a builder
|
||||
@Setter(AccessLevel.PROTECTED)
|
||||
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
||||
|
||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
||||
MDP<O, Integer, DiscreteSpace> mdp,
|
||||
TrainingListenerList listeners,
|
||||
int threadNumber,
|
||||
int deviceNum) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
synchronized (asyncGlobal) {
|
||||
current = (NN)asyncGlobal.getCurrent().clone();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add an actor-learner class and be able to inject the update algorithm
|
||||
protected abstract UpdateAlgorithm<NN> buildUpdateAlgorithm();
|
||||
|
||||
@Override
|
||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||
super.setHistoryProcessor(historyProcessor);
|
||||
updateAlgorithm = buildUpdateAlgorithm();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void preEpoch() {
|
||||
experienceHandler.reset();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* "Subepoch" correspond to the t_max-step iterations
|
||||
* that stack rewards with t_max MiniTrans
|
||||
|
@ -65,13 +90,11 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
synchronized (getAsyncGlobal()) {
|
||||
current.copy(getAsyncGlobal().getCurrent());
|
||||
}
|
||||
Stack<MiniTrans<Integer>> rewards = new Stack<>();
|
||||
|
||||
Observation obs = sObs;
|
||||
IPolicy<O, Integer> policy = getPolicy(current);
|
||||
|
||||
Integer action;
|
||||
Integer lastAction = getMdp().getActionSpace().noOp();
|
||||
Integer action = getMdp().getActionSpace().noOp();
|
||||
IHistoryProcessor hp = getHistoryProcessor();
|
||||
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
||||
|
||||
|
@ -82,21 +105,15 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||
|
||||
//if step of training, just repeat lastAction
|
||||
if (obs.isSkipped()) {
|
||||
action = lastAction;
|
||||
} else {
|
||||
if (!obs.isSkipped()) {
|
||||
action = policy.nextAction(obs);
|
||||
}
|
||||
|
||||
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
|
||||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (!obs.isSkipped()) {
|
||||
|
||||
INDArray[] output = current.outputAll(obs.getData());
|
||||
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
||||
|
||||
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
|
||||
accuReward = 0;
|
||||
}
|
||||
|
||||
|
@ -104,29 +121,14 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
reward += stepReply.getReward();
|
||||
|
||||
incrementStep();
|
||||
lastAction = action;
|
||||
}
|
||||
|
||||
//a bit of a trick usable because of how the stack is treated to init R
|
||||
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
|
||||
|
||||
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
|
||||
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
||||
else {
|
||||
INDArray[] output = null;
|
||||
if (getConf().getLearnerUpdateFrequency() == -1)
|
||||
output = current.outputAll(obs.getData());
|
||||
else synchronized (getAsyncGlobal()) {
|
||||
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
|
||||
}
|
||||
double maxQ = Nd4j.max(output[0]).getDouble(0);
|
||||
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
|
||||
if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||
experienceHandler.setFinalObservation(obs);
|
||||
}
|
||||
|
||||
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
|
||||
getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
|
||||
|
||||
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
|
||||
}
|
||||
|
||||
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,28 +13,14 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Value;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
*
|
||||
* Its called a MiniTrans because it is similar to a Transition
|
||||
* but without a next observation
|
||||
*
|
||||
* It is stacked and then processed by AsyncNStepQL or A3C
|
||||
* following the paper implementation https://arxiv.org/abs/1602.01783 paper.
|
||||
*
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Value
|
||||
public class MiniTrans<A> {
|
||||
INDArray obs;
|
||||
A action;
|
||||
INDArray[] output;
|
||||
double reward;
|
||||
import java.util.List;
|
||||
|
||||
public interface UpdateAlgorithm<NN extends NeuralNet> {
|
||||
Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience);
|
||||
}
|
|
@ -18,11 +18,7 @@
|
|||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.async.*;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -34,9 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
||||
import java.util.Stack;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||
|
@ -67,6 +61,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
if(seed != null) {
|
||||
rnd.setSeed(seed + threadNumber);
|
||||
}
|
||||
|
||||
setUpdateAlgorithm(buildUpdateAlgorithm());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -74,52 +70,9 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
|||
return new ACPolicy(net, rnd);
|
||||
}
|
||||
|
||||
/**
|
||||
* calc the gradients based on the n-step rewards
|
||||
*/
|
||||
@Override
|
||||
public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) {
|
||||
MiniTrans<Integer> minTrans = rewards.pop();
|
||||
|
||||
int size = rewards.size();
|
||||
|
||||
//if recurrent then train as a time serie with a batch size of 1
|
||||
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
|
||||
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||
: getHistoryProcessor().getConf().getShape();
|
||||
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
||||
: Learning.makeShape(size, shape);
|
||||
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
|
||||
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
|
||||
|
||||
double r = minTrans.getReward();
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
minTrans = rewards.pop();
|
||||
|
||||
r = minTrans.getReward() + conf.getGamma() * r;
|
||||
if (recurrent) {
|
||||
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(minTrans.getObs());
|
||||
} else {
|
||||
input.putRow(i, minTrans.getObs());
|
||||
}
|
||||
|
||||
//the critic
|
||||
targets.putScalar(i, r);
|
||||
|
||||
//the actor
|
||||
double expectedV = minTrans.getOutput()[0].getDouble(0);
|
||||
double advantage = r - expectedV;
|
||||
if (recurrent) {
|
||||
logSoftmax.putScalar(0, minTrans.getAction(), i, advantage);
|
||||
} else {
|
||||
logSoftmax.putScalar(i, minTrans.getAction(), advantage);
|
||||
}
|
||||
}
|
||||
|
||||
return iac.gradient(input, new INDArray[] {targets, logSoftmax});
|
||||
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||
return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
|
||||
|
||||
private final IAsyncGlobal asyncGlobal;
|
||||
private final int[] shape;
|
||||
private final int actionSpaceSize;
|
||||
private final int targetDqnUpdateFreq;
|
||||
private final double gamma;
|
||||
private final boolean recurrent;
|
||||
|
||||
public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
|
||||
int[] shape,
|
||||
int actionSpaceSize,
|
||||
int targetDqnUpdateFreq,
|
||||
double gamma) {
|
||||
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
|
||||
//if recurrent then train as a time serie with a batch size of 1
|
||||
recurrent = asyncGlobal.getCurrent().isRecurrent();
|
||||
this.shape = shape;
|
||||
this.actionSpaceSize = actionSpaceSize;
|
||||
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] computeGradients(IActorCritic current, List<StateActionPair<Integer>> experience) {
|
||||
int size = experience.size();
|
||||
|
||||
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
||||
: Learning.makeShape(size, shape);
|
||||
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, actionSpaceSize, size)
|
||||
: Nd4j.zeros(size, actionSpaceSize);
|
||||
|
||||
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||
double r;
|
||||
if(stateActionPair.isTerminal()) {
|
||||
r = 0;
|
||||
}
|
||||
else {
|
||||
INDArray[] output = null;
|
||||
if (targetDqnUpdateFreq == -1)
|
||||
output = current.outputAll(stateActionPair.getObservation().getData());
|
||||
else synchronized (asyncGlobal) {
|
||||
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
|
||||
}
|
||||
r = output[0].getDouble(0);
|
||||
}
|
||||
|
||||
for (int i = size - 1; i >= 0; --i) {
|
||||
stateActionPair = experience.get(i);
|
||||
|
||||
INDArray observationData = stateActionPair.getObservation().getData();
|
||||
|
||||
INDArray[] output = current.outputAll(observationData);
|
||||
|
||||
r = stateActionPair.getReward() + gamma * r;
|
||||
if (recurrent) {
|
||||
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
|
||||
} else {
|
||||
input.putRow(i, observationData);
|
||||
}
|
||||
|
||||
//the critic
|
||||
targets.putScalar(i, r);
|
||||
|
||||
//the actor
|
||||
double expectedV = output[0].getDouble(0);
|
||||
double advantage = r - expectedV;
|
||||
if (recurrent) {
|
||||
logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
|
||||
} else {
|
||||
logSoftmax.putScalar(i, stateActionPair.getAction(), advantage);
|
||||
}
|
||||
}
|
||||
|
||||
// targets -> value, critic
|
||||
// logSoftmax -> policy, actor
|
||||
return current.gradient(input, new INDArray[] {targets, logSoftmax});
|
||||
}
|
||||
}
|
|
@ -18,11 +18,9 @@
|
|||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -32,12 +30,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
|||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
||||
import java.util.Stack;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
*/
|
||||
|
@ -65,6 +60,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
if(seed != null) {
|
||||
rnd.setSeed(seed + threadNumber);
|
||||
}
|
||||
|
||||
setUpdateAlgorithm(buildUpdateAlgorithm());
|
||||
}
|
||||
|
||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||
|
@ -72,32 +69,9 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
|||
rnd, conf.getMinEpsilon(), this);
|
||||
}
|
||||
|
||||
|
||||
|
||||
//calc the gradient based on the n-step rewards
|
||||
public Gradient[] calcGradient(IDQN current, Stack<MiniTrans<Integer>> rewards) {
|
||||
|
||||
MiniTrans<Integer> minTrans = rewards.pop();
|
||||
|
||||
int size = rewards.size();
|
||||
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
||||
: getHistoryProcessor().getConf().getShape();
|
||||
int[] nshape = Learning.makeShape(size, shape);
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
|
||||
|
||||
double r = minTrans.getReward();
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
minTrans = rewards.pop();
|
||||
|
||||
r = minTrans.getReward() + conf.getGamma() * r;
|
||||
input.putRow(i, minTrans.getObs());
|
||||
INDArray row = minTrans.getOutput()[0];
|
||||
row = row.putScalar(minTrans.getAction(), r);
|
||||
targets.putRow(i, row);
|
||||
}
|
||||
|
||||
return current.gradient(input, targets);
|
||||
@Override
|
||||
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||
return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
||||
|
||||
private final IAsyncGlobal asyncGlobal;
|
||||
private final int[] shape;
|
||||
private final int actionSpaceSize;
|
||||
private final int targetDqnUpdateFreq;
|
||||
private final double gamma;
|
||||
|
||||
public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
|
||||
int[] shape,
|
||||
int actionSpaceSize,
|
||||
int targetDqnUpdateFreq,
|
||||
double gamma) {
|
||||
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.shape = shape;
|
||||
this.actionSpaceSize = actionSpaceSize;
|
||||
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
|
||||
this.gamma = gamma;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
|
||||
int size = experience.size();
|
||||
|
||||
int[] nshape = Learning.makeShape(size, shape);
|
||||
INDArray input = Nd4j.create(nshape);
|
||||
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
||||
|
||||
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||
|
||||
double r;
|
||||
if(stateActionPair.isTerminal()) {
|
||||
r = 0;
|
||||
}
|
||||
else {
|
||||
INDArray[] output = null;
|
||||
if (targetDqnUpdateFreq == -1)
|
||||
output = current.outputAll(stateActionPair.getObservation().getData());
|
||||
else synchronized (asyncGlobal) {
|
||||
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
|
||||
}
|
||||
r = Nd4j.max(output[0]).getDouble(0);
|
||||
}
|
||||
|
||||
for (int i = size - 1; i >= 0; i--) {
|
||||
stateActionPair = experience.get(i);
|
||||
|
||||
input.putRow(i, stateActionPair.getObservation().getData());
|
||||
|
||||
r = stateActionPair.getReward() + gamma * r;
|
||||
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
|
||||
INDArray row = output[0];
|
||||
row = row.putScalar(stateActionPair.getAction(), r);
|
||||
targets.putRow(i, row);
|
||||
}
|
||||
|
||||
return current.gradient(input, targets);
|
||||
}
|
||||
}
|
|
@ -80,6 +80,9 @@ public class ExpReplay<A> implements IExpReplay<A> {
|
|||
//log.info("size: "+storage.size());
|
||||
}
|
||||
|
||||
|
||||
public int getBatchSize() {
|
||||
int storageSize = storage.size();
|
||||
return Math.min(storageSize, batchSize);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -32,6 +32,11 @@ import java.util.ArrayList;
|
|||
*/
|
||||
public interface IExpReplay<A> {
|
||||
|
||||
/**
|
||||
* @return The size of the batch that will be returned by getBatch()
|
||||
*/
|
||||
int getBatchSize();
|
||||
|
||||
/**
|
||||
* @return a batch of uniformly sampled transitions
|
||||
*/
|
||||
|
@ -42,5 +47,4 @@ public interface IExpReplay<A> {
|
|||
* @param transition a new transition to store
|
||||
*/
|
||||
void store(Transition<A> transition);
|
||||
|
||||
}
|
||||
|
|
|
@ -60,32 +60,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
extends SyncLearning<O, A, AS, IDQN>
|
||||
implements TargetQNetworkSource, EpochStepCounter {
|
||||
|
||||
// FIXME Changed for refac
|
||||
// @Getter
|
||||
// final private IExpReplay<A> expReplay;
|
||||
@Getter
|
||||
@Setter(AccessLevel.PROTECTED)
|
||||
protected IExpReplay<A> expReplay;
|
||||
|
||||
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||
|
||||
public QLearning(QLearningConfiguration conf) {
|
||||
this(conf, getSeededRandom(conf.getSeed()));
|
||||
}
|
||||
|
||||
public QLearning(QLearningConfiguration conf, Random random) {
|
||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||
}
|
||||
|
||||
private static Random getSeededRandom(Long seed) {
|
||||
Random rnd = Nd4j.getRandom();
|
||||
if(seed != null) {
|
||||
rnd.setSeed(seed);
|
||||
}
|
||||
|
||||
return rnd;
|
||||
}
|
||||
|
||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
||||
|
||||
public abstract MDP<O, A, AS> getMdp();
|
||||
|
|
|
@ -21,6 +21,8 @@ import lombok.AccessLevel;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||
|
@ -42,7 +44,7 @@ import org.nd4j.linalg.api.rng.Random;
|
|||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -71,10 +73,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
private int lastAction;
|
||||
private double accuReward = 0;
|
||||
|
||||
private Transition pendingTransition;
|
||||
|
||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||
|
||||
// TODO: User a builder and remove the setter
|
||||
@Getter(AccessLevel.PROTECTED) @Setter
|
||||
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
|
||||
|
||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||
return mdp;
|
||||
}
|
||||
|
@ -85,7 +89,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
||||
int epsilonNbStep, Random random) {
|
||||
super(conf);
|
||||
this.configuration = conf;
|
||||
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
|
||||
qNetwork = dqn;
|
||||
|
@ -98,6 +101,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
||||
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
|
||||
|
||||
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||
}
|
||||
|
||||
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
||||
|
@ -114,7 +118,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
public void preEpoch() {
|
||||
lastAction = mdp.getActionSpace().noOp();
|
||||
accuReward = 0;
|
||||
pendingTransition = null;
|
||||
experienceHandler.reset();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -131,8 +135,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
*/
|
||||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||
|
||||
Integer action;
|
||||
|
||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||
|
@ -142,37 +144,28 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||
|
||||
//if step of training, just repeat lastAction
|
||||
if (obs.isSkipped()) {
|
||||
action = lastAction;
|
||||
} else {
|
||||
if (!obs.isSkipped()) {
|
||||
INDArray qs = getQNetwork().output(obs);
|
||||
int maxAction = Learning.getMaxAction(qs);
|
||||
maxQ = qs.getDouble(maxAction);
|
||||
|
||||
action = getEgPolicy().nextAction(obs);
|
||||
lastAction = getEgPolicy().nextAction(obs);
|
||||
}
|
||||
|
||||
lastAction = action;
|
||||
|
||||
StepReply<Observation> stepReply = mdp.step(action);
|
||||
|
||||
StepReply<Observation> stepReply = mdp.step(lastAction);
|
||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (!obs.isSkipped()) {
|
||||
|
||||
// Add experience
|
||||
if (pendingTransition != null) {
|
||||
pendingTransition.setNextObservation(obs);
|
||||
getExpReplay().store(pendingTransition);
|
||||
}
|
||||
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
|
||||
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
|
||||
accuReward = 0;
|
||||
|
||||
// Update NN
|
||||
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
||||
if (getStepCounter() > updateStart) {
|
||||
DataSet targets = setTarget(getExpReplay().getBatch());
|
||||
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
|
||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||
}
|
||||
}
|
||||
|
@ -180,7 +173,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||
}
|
||||
|
||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
||||
if (transitions.size() == 0)
|
||||
throw new IllegalArgumentException("too few transitions");
|
||||
|
||||
|
@ -189,9 +182,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
@Override
|
||||
protected void finishEpoch(Observation observation) {
|
||||
if (pendingTransition != null) {
|
||||
pendingTransition.setNextObservation(observation);
|
||||
getExpReplay().store(pendingTransition);
|
||||
}
|
||||
experienceHandler.setFinalObservation(observation);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ReplayMemoryExperienceHandlerTest {
|
||||
@Test
|
||||
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
int numStoredTransitions = expReplayMock.addedTransitions.size();
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
|
||||
// Assert
|
||||
assertEquals(0, numStoredTransitions);
|
||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_transitionsAreCorrect() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||
|
||||
// Assert
|
||||
assertEquals(2, expReplayMock.addedTransitions.size());
|
||||
|
||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
||||
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
||||
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
|
||||
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
|
||||
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
|
||||
// Act
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 })));
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||
// Arrange
|
||||
TestExpReplay expReplayMock = new TestExpReplay();
|
||||
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||
|
||||
// Act
|
||||
int size = sut.getTrainingBatchSize();
|
||||
// Assert
|
||||
assertEquals(2, size);
|
||||
}
|
||||
|
||||
private static class TestExpReplay implements IExpReplay<Integer> {
|
||||
|
||||
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public ArrayList<Transition<Integer>> getBatch() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void store(Transition<Integer> transition) {
|
||||
addedTransitions.add(transition);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getBatchSize() {
|
||||
return addedTransitions.size();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
package org.deeplearning4j.rl4j.experience;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class StateActionExperienceHandlerTest {
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
sut.reset();
|
||||
Observation observation = new Observation(Nd4j.zeros(1));
|
||||
sut.addExperience(observation, 123, 234.0, true);
|
||||
|
||||
// Act
|
||||
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, result.size());
|
||||
assertSame(observation, result.get(0).getObservation());
|
||||
assertEquals(123, (int)result.get(0).getAction());
|
||||
assertEquals(234.0, result.get(0).getReward(), 0.00001);
|
||||
assertTrue(result.get(0).isTerminal());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
sut.addExperience(null, 3, 3.0, false);
|
||||
|
||||
// Act
|
||||
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
|
||||
|
||||
// Assert
|
||||
assertEquals(3, result.size());
|
||||
assertEquals(1, (int)result.get(0).getAction());
|
||||
assertEquals(2, (int)result.get(1).getAction());
|
||||
assertEquals(3, (int)result.get(2).getAction());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_gettingExperience_expect_experienceStoreIsCleared() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
|
||||
// Act
|
||||
List<StateActionPair<Integer>> firstResult = sut.generateTrainingBatch();
|
||||
List<StateActionPair<Integer>> secondResult = sut.generateTrainingBatch();
|
||||
|
||||
// Assert
|
||||
assertEquals(1, firstResult.size());
|
||||
assertEquals(0, secondResult.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||
// Arrange
|
||||
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||
sut.reset();
|
||||
sut.addExperience(null, 1, 1.0, false);
|
||||
sut.addExperience(null, 2, 2.0, false);
|
||||
sut.addExperience(null, 3, 3.0, false);
|
||||
|
||||
// Act
|
||||
int size = sut.getTrainingBatchSize();
|
||||
|
||||
// Assert
|
||||
assertEquals(3, size);
|
||||
}
|
||||
}
|
|
@ -18,9 +18,12 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
|
@ -31,7 +34,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Stack;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
|
@ -51,7 +53,9 @@ public class AsyncThreadDiscreteTest {
|
|||
TrainingListenerList listeners = new TrainingListenerList();
|
||||
MockPolicy policyMock = new MockPolicy();
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
|
||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||
MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
|
||||
MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
|
||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
|
||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||
|
||||
// Act
|
||||
|
@ -102,62 +106,22 @@ public class AsyncThreadDiscreteTest {
|
|||
}
|
||||
}
|
||||
|
||||
// NeuralNetwork
|
||||
assertEquals(2, nnMock.copyCallCount);
|
||||
double[][] expectedNNInputs = new double[][] {
|
||||
// ExperienceHandler
|
||||
double[][] expectedExperienceHandlerInputs = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
};
|
||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
||||
double[] expectedRow = expectedNNInputs[i];
|
||||
INDArray input = nnMock.outputAllInputs.get(i);
|
||||
assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
|
||||
for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
|
||||
double[] expectedRow = expectedExperienceHandlerInputs[i];
|
||||
INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
|
||||
assertEquals(expectedRow.length, input.shape()[1]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
int arrayIdx = 0;
|
||||
double[][][] expectedMinitransObs = new double[][][] {
|
||||
new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
|
||||
},
|
||||
new double[][] {
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
|
||||
}
|
||||
};
|
||||
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
||||
double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 };
|
||||
|
||||
assertEquals(2, sut.rewards.size());
|
||||
for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) {
|
||||
Stack<MiniTrans<Integer>> miniTransStack = sut.rewards.get(rewardIdx);
|
||||
|
||||
for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) {
|
||||
MiniTrans minitrans = miniTransStack.get(i);
|
||||
|
||||
// Observation
|
||||
double[] expectedRow = expectedMinitransObs[rewardIdx][i];
|
||||
INDArray realRewards = minitrans.getObs();
|
||||
assertEquals(expectedRow.length, realRewards.shape()[1]);
|
||||
for (int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001);
|
||||
}
|
||||
|
||||
assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001);
|
||||
assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001);
|
||||
++arrayIdx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
||||
|
@ -167,22 +131,19 @@ public class AsyncThreadDiscreteTest {
|
|||
private final MockAsyncConfiguration config;
|
||||
|
||||
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
|
||||
public final List<Stack<MiniTrans<Integer>>> rewards = new ArrayList<Stack<MiniTrans<Integer>>>();
|
||||
|
||||
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
||||
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
||||
MockAsyncConfiguration config, IHistoryProcessor hp) {
|
||||
MockAsyncConfiguration config, IHistoryProcessor hp,
|
||||
ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
|
||||
UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
|
||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||
this.asyncGlobal = asyncGlobal;
|
||||
this.policy = policy;
|
||||
this.config = config;
|
||||
setHistoryProcessor(hp);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
|
||||
this.rewards.add(rewards);
|
||||
return new Gradient[0];
|
||||
setExperienceHandler(experienceHandler);
|
||||
setUpdateAlgorithm(updateAlgorithm);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -200,6 +161,11 @@ public class AsyncThreadDiscreteTest {
|
|||
return policy;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
|
||||
asyncGlobal.increaseCurrentLoop();
|
||||
|
|
|
@ -1,197 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Stack;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class A3CThreadDiscreteTest {
|
||||
|
||||
@Test
|
||||
public void refac_calcGradient() {
|
||||
// Arrange
|
||||
double gamma = 0.9;
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||
A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
|
||||
MockActorCritic actorCriticMock = new MockActorCritic();
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
||||
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
||||
A3CThreadDiscrete sut = new A3CThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, 0, null, 0);
|
||||
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hpMock);
|
||||
|
||||
double[][] minitransObs = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
};
|
||||
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
|
||||
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
|
||||
|
||||
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
|
||||
for(int i = 0; i < 3; ++i) {
|
||||
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
|
||||
INDArray[] output = new INDArray[] {
|
||||
Nd4j.zeros(5)
|
||||
};
|
||||
output[0].putScalar(i, outputs[i]);
|
||||
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
||||
}
|
||||
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
||||
|
||||
// Act
|
||||
sut.calcGradient(actorCriticMock, minitransList);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, actorCriticMock.gradientParams.size());
|
||||
INDArray input = actorCriticMock.gradientParams.get(0).getFirst();
|
||||
INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond();
|
||||
|
||||
assertEquals(minitransObs.length, input.shape()[0]);
|
||||
for(int i = 0; i < minitransObs.length; ++i) {
|
||||
double[] expectedRow = minitransObs[i];
|
||||
assertEquals(expectedRow.length, input.shape()[1]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
double latestReward = (gamma * 4.0) + 3.0;
|
||||
double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward };
|
||||
for(int i = 0; i < expectedLabels0.length; ++i) {
|
||||
assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001);
|
||||
}
|
||||
double[][] expectedLabels1 = new double[][] {
|
||||
new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 },
|
||||
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
|
||||
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
|
||||
};
|
||||
|
||||
assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape());
|
||||
|
||||
for(int i = 0; i < expectedLabels1.length; ++i) {
|
||||
double[] expectedRow = expectedLabels1[i];
|
||||
assertEquals(expectedRow.length, labels[1].shape()[1]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public class MockActorCritic implements IActorCritic {
|
||||
|
||||
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public IActorCritic clone() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(IActorCritic from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradient, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String pathValue, String pathPolicy) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class A3CUpdateAlgorithmTest {
|
||||
|
||||
@Test
|
||||
public void refac_calcGradient_non_terminal() {
|
||||
// Arrange
|
||||
double gamma = 0.9;
|
||||
MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
|
||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||
MockActorCritic actorCriticMock = new MockActorCritic();
|
||||
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
||||
A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
|
||||
|
||||
|
||||
INDArray[] originalObservations = new INDArray[] {
|
||||
Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
|
||||
Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
|
||||
Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
|
||||
Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
|
||||
};
|
||||
int[] actions = new int[] { 0, 1, 2, 1 };
|
||||
double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
|
||||
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
|
||||
for(int i = 0; i < originalObservations.length; ++i) {
|
||||
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
|
||||
}
|
||||
|
||||
// Act
|
||||
sut.computeGradients(actorCriticMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, actorCriticMock.gradientParams.size());
|
||||
|
||||
// Inputs
|
||||
INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
for(int j = 0; j < 5; ++j) {
|
||||
assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
|
||||
INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
|
||||
|
||||
assertEquals(4, targets.shape()[0]);
|
||||
assertEquals(1, targets.shape()[1]);
|
||||
|
||||
// FIXME: check targets values once fixed
|
||||
|
||||
assertEquals(4, logSoftmax.shape()[0]);
|
||||
assertEquals(5, logSoftmax.shape()[1]);
|
||||
|
||||
// FIXME: check logSoftmax values once fixed
|
||||
|
||||
}
|
||||
|
||||
public class MockActorCritic implements IActorCritic {
|
||||
|
||||
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[] { batch.mul(-1.0) };
|
||||
}
|
||||
|
||||
@Override
|
||||
public IActorCritic clone() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(IActorCritic from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradient, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String pathValue, String pathPolicy) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,98 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2020 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Stack;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class AsyncNStepQLearningThreadDiscreteTest {
|
||||
|
||||
@Test
|
||||
public void refac_calcGradient() {
|
||||
// Arrange
|
||||
double gamma = 0.9;
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||
AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
||||
AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, null, 0, 0);
|
||||
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hpMock);
|
||||
|
||||
double[][] minitransObs = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
};
|
||||
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
|
||||
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
|
||||
|
||||
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
|
||||
for(int i = 0; i < 3; ++i) {
|
||||
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
|
||||
INDArray[] output = new INDArray[] {
|
||||
Nd4j.zeros(5)
|
||||
};
|
||||
output[0].putScalar(i, outputs[i]);
|
||||
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
||||
}
|
||||
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
||||
|
||||
// Act
|
||||
sut.calcGradient(dqnMock, minitransList);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, dqnMock.gradientParams.size());
|
||||
INDArray input = dqnMock.gradientParams.get(0).getFirst();
|
||||
INDArray labels = dqnMock.gradientParams.get(0).getSecond();
|
||||
|
||||
assertEquals(minitransObs.length, input.shape()[0]);
|
||||
for(int i = 0; i < minitransObs.length; ++i) {
|
||||
double[] expectedRow = minitransObs[i];
|
||||
assertEquals(expectedRow.length, input.shape()[1]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
double latestReward = (gamma * 4.0) + 3.0;
|
||||
double[][] expectedLabels = new double[][] {
|
||||
new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
|
||||
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
|
||||
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
|
||||
};
|
||||
assertEquals(minitransObs.length, labels.shape()[0]);
|
||||
for(int i = 0; i < minitransObs.length; ++i) {
|
||||
double[] expectedRow = expectedLabels[i];
|
||||
assertEquals(expectedRow.length, labels.shape()[1]);
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||
import org.deeplearning4j.rl4j.support.MockDQN;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class QLearningUpdateAlgorithmTest {
|
||||
|
||||
@Test
|
||||
public void when_isTerminal_expect_initRewardIs0() {
|
||||
// Arrange
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
|
||||
}
|
||||
};
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
|
||||
// Arrange
|
||||
MockDQN globalDQNMock = new MockDQN();
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
|
||||
}
|
||||
};
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, dqnMock.outputAllParams.size());
|
||||
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
||||
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
|
||||
// Arrange
|
||||
MockDQN globalDQNMock = new MockDQN();
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
|
||||
}
|
||||
};
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, globalDQNMock.outputAllParams.size());
|
||||
assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
||||
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
|
||||
// Arrange
|
||||
double gamma = 0.9;
|
||||
MockDQN globalDQNMock = new MockDQN();
|
||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
|
||||
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||
{
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
|
||||
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
|
||||
}
|
||||
};
|
||||
MockDQN dqnMock = new MockDQN();
|
||||
|
||||
// Act
|
||||
sut.computeGradients(dqnMock, experience);
|
||||
|
||||
// Assert
|
||||
// input side -- should be a stack of observations
|
||||
INDArray input = dqnMock.gradientParams.get(0).getLeft();
|
||||
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
|
||||
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
|
||||
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
|
||||
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
|
||||
|
||||
// target side
|
||||
INDArray target = dqnMock.gradientParams.get(0).getRight();
|
||||
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
|
||||
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
|
||||
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
|
||||
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
|
||||
}
|
||||
}
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
|
@ -75,8 +77,8 @@ public class QLearningDiscreteTest {
|
|||
.build();
|
||||
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||
MockExperienceHandler experienceHandler = new MockExperienceHandler();
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
|
@ -93,7 +95,6 @@ public class QLearningDiscreteTest {
|
|||
for (int i = 0; i < expectedRecords.length; ++i) {
|
||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
|
||||
assertEquals(0, hp.startMonitorCallCount);
|
||||
assertEquals(0, hp.stopMonitorCallCount);
|
||||
|
||||
|
@ -133,7 +134,7 @@ public class QLearningDiscreteTest {
|
|||
// MDP calls
|
||||
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
|
||||
|
||||
// ExpReplay calls
|
||||
// ExperienceHandler calls
|
||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||
|
@ -147,16 +148,17 @@ public class QLearningDiscreteTest {
|
|||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||
};
|
||||
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
||||
|
||||
assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
|
||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||
Transition tr = expReplay.transitions.get(i);
|
||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||
assertEquals(expectedTrActions[i], tr.getAction());
|
||||
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
|
||||
StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
|
||||
assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
|
||||
assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
|
||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
||||
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
||||
}
|
||||
}
|
||||
assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
|
||||
|
||||
// trainEpoch result
|
||||
assertEquals(initStepCount + 16, result.getStepCounter());
|
||||
|
@ -167,22 +169,18 @@ public class QLearningDiscreteTest {
|
|||
|
||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||
QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||
QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
|
||||
int epsilonNbStep, Random rnd) {
|
||||
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
||||
addListener(new DataManagerTrainingListener(dataManager));
|
||||
setExpReplay(expReplay);
|
||||
setExperienceHandler(experienceHandler);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
||||
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
||||
}
|
||||
|
||||
public void setExpReplay(IExpReplay<Integer> exp) {
|
||||
this.expReplay = exp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IDataManager.StatEntry trainEpoch() {
|
||||
return super.trainEpoch();
|
||||
|
|
|
@ -19,6 +19,7 @@ public class MockDQN implements IDQN {
|
|||
public final List<INDArray> outputParams = new ArrayList<>();
|
||||
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
|
||||
public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
|
||||
public final List<INDArray> outputAllParams = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
|
@ -58,7 +59,8 @@ public class MockDQN implements IDQN {
|
|||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
outputAllParams.add(batch);
|
||||
return new INDArray[] { batch.mul(-1.0) };
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockExpReplay implements IExpReplay<Integer> {
|
||||
|
||||
public List<Transition<Integer>> transitions = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public ArrayList<Transition<Integer>> getBatch() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void store(Transition<Integer> transition) {
|
||||
transitions.add(transition);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
|
||||
public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
|
||||
public Observation finalObservation;
|
||||
public boolean isGenerateTrainingBatchCalled;
|
||||
public boolean isResetCalled;
|
||||
|
||||
@Override
|
||||
public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
|
||||
addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setFinalObservation(Observation observation) {
|
||||
finalObservation = observation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Transition<Integer>> generateTrainingBatch() {
|
||||
isGenerateTrainingBatchCalled = true;
|
||||
return new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(null, 0, 0.0, false));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
isResetCalled = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getTrainingBatchSize() {
|
||||
return 1;
|
||||
}
|
||||
}
|
|
@ -5,6 +5,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
|
||||
public class MockObservationSpace implements ObservationSpace {
|
||||
|
||||
private final int[] shape;
|
||||
|
||||
public MockObservationSpace() {
|
||||
this(new int[] { 1 });
|
||||
}
|
||||
|
||||
public MockObservationSpace(int[] shape) {
|
||||
this.shape = shape;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
|
@ -12,7 +22,7 @@ public class MockObservationSpace implements ObservationSpace {
|
|||
|
||||
@Override
|
||||
public int[] getShape() {
|
||||
return new int[] { 1 };
|
||||
return shape;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
|
||||
|
||||
public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
|
||||
|
||||
@Override
|
||||
public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
|
||||
experiences.add(experience);
|
||||
return new Gradient[0];
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue