RL4J: Add ExperienceHandler (#369)
* Added ExperienceHandler Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Added getTrainingBatchSize() Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
3e2dbc65dd
commit
f1debe8c07
|
@ -0,0 +1,54 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A common interface to all classes capable of handling experience generated by the agents in a learning context.
|
||||||
|
*
|
||||||
|
* @param <A> Action type
|
||||||
|
* @param <E> Experience type
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface ExperienceHandler<A, E> {
|
||||||
|
void addExperience(Observation observation, A action, double reward, boolean isTerminal);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the episode is done with the last observation
|
||||||
|
* @param observation
|
||||||
|
*/
|
||||||
|
void setFinalObservation(Observation observation);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The size of the list that will be returned by generateTrainingBatch().
|
||||||
|
*/
|
||||||
|
int getTrainingBatchSize();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||||
|
* @return The list of experience elements
|
||||||
|
*/
|
||||||
|
List<E> generateTrainingBatch();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Signal the experience handler that a new episode is starting
|
||||||
|
*/
|
||||||
|
void reset();
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A experience handler that stores the experience in a replay memory. See https://arxiv.org/abs/1312.5602
|
||||||
|
* The experience container is a {@link Transition Transition} that stores the tuple observation-action-reward-nextObservation,
|
||||||
|
* as well as whether or the not the episode ended after the Transition
|
||||||
|
*
|
||||||
|
* @param <A> Action type
|
||||||
|
*/
|
||||||
|
@EqualsAndHashCode
|
||||||
|
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, Transition<A>> {
|
||||||
|
private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
|
||||||
|
private static final int DEFAULT_BATCH_SIZE = 32;
|
||||||
|
|
||||||
|
private IExpReplay<A> expReplay;
|
||||||
|
|
||||||
|
private Transition<A> pendingTransition;
|
||||||
|
|
||||||
|
public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
|
||||||
|
this.expReplay = expReplay;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ReplayMemoryExperienceHandler(int maxReplayMemorySize, int batchSize, Random random) {
|
||||||
|
this(new ExpReplay<A>(maxReplayMemorySize, batchSize, random));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||||
|
setNextObservationOnPending(observation);
|
||||||
|
pendingTransition = new Transition<>(observation, action, reward, isTerminal);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setFinalObservation(Observation observation) {
|
||||||
|
setNextObservationOnPending(observation);
|
||||||
|
pendingTransition = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getTrainingBatchSize() {
|
||||||
|
return expReplay.getBatchSize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return A batch of experience selected from the replay memory. The replay memory is unchanged after the call.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<Transition<A>> generateTrainingBatch() {
|
||||||
|
return expReplay.getBatch();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
pendingTransition = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setNextObservationOnPending(Observation observation) {
|
||||||
|
if(pendingTransition != null) {
|
||||||
|
pendingTransition.setNextObservation(observation);
|
||||||
|
expReplay.store(pendingTransition);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class Builder {
|
||||||
|
private int maxReplayMemorySize = DEFAULT_MAX_REPLAY_MEMORY_SIZE;
|
||||||
|
private int batchSize = DEFAULT_BATCH_SIZE;
|
||||||
|
private Random random = Nd4j.getRandom();
|
||||||
|
|
||||||
|
public Builder maxReplayMemorySize(int value) {
|
||||||
|
maxReplayMemorySize = value;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder batchSize(int value) {
|
||||||
|
batchSize = value;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Builder random(Random value) {
|
||||||
|
random = value;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ReplayMemoryExperienceHandler<A> build() {
|
||||||
|
return new ReplayMemoryExperienceHandler<A>(maxReplayMemorySize, batchSize, random);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,67 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple {@link ExperienceHandler experience handler} that stores the experiences.
|
||||||
|
* Note: Calling {@link StateActionExperienceHandler#generateTrainingBatch() generateTrainingBatch()} will clear the stored experiences
|
||||||
|
*
|
||||||
|
* @param <A> Action type
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {
|
||||||
|
|
||||||
|
private List<StateActionPair<A>> stateActionPairs;
|
||||||
|
|
||||||
|
public void setFinalObservation(Observation observation) {
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
|
||||||
|
stateActionPairs.add(new StateActionPair<A>(observation, action, reward, isTerminal));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getTrainingBatchSize() {
|
||||||
|
return stateActionPairs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The elements are returned in the historical order (i.e. in the order they happened)
|
||||||
|
* Note: the experience store is cleared after calling this method.
|
||||||
|
*
|
||||||
|
* @return The list of experience elements
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public List<StateActionPair<A>> generateTrainingBatch() {
|
||||||
|
List<StateActionPair<A>> result = stateActionPairs;
|
||||||
|
stateActionPairs = new ArrayList<>();
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
stateActionPairs = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,49 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A simple experience container. Used by {@link StateActionExperienceHandler StateActionExperienceHandler}.
|
||||||
|
*
|
||||||
|
* @param <A> Action type
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class StateActionPair<A> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The observation before the action is taken
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private final Observation observation;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final A action;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
private final double reward;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* True if the episode ended after the action has been taken.
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private final boolean terminal;
|
||||||
|
}
|
|
@ -18,9 +18,12 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -28,10 +31,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
|
@ -45,13 +44,39 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
@Getter
|
@Getter
|
||||||
private NN current;
|
private NN current;
|
||||||
|
|
||||||
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal, MDP<O, Integer, DiscreteSpace> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
@Setter(AccessLevel.PROTECTED)
|
||||||
|
private UpdateAlgorithm<NN> updateAlgorithm;
|
||||||
|
|
||||||
|
// TODO: Make it configurable with a builder
|
||||||
|
@Setter(AccessLevel.PROTECTED)
|
||||||
|
private ExperienceHandler experienceHandler = new StateActionExperienceHandler();
|
||||||
|
|
||||||
|
public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
|
||||||
|
MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
TrainingListenerList listeners,
|
||||||
|
int threadNumber,
|
||||||
|
int deviceNum) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
synchronized (asyncGlobal) {
|
synchronized (asyncGlobal) {
|
||||||
current = (NN)asyncGlobal.getCurrent().clone();
|
current = (NN)asyncGlobal.getCurrent().clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Add an actor-learner class and be able to inject the update algorithm
|
||||||
|
protected abstract UpdateAlgorithm<NN> buildUpdateAlgorithm();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||||
|
super.setHistoryProcessor(historyProcessor);
|
||||||
|
updateAlgorithm = buildUpdateAlgorithm();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void preEpoch() {
|
||||||
|
experienceHandler.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* "Subepoch" correspond to the t_max-step iterations
|
* "Subepoch" correspond to the t_max-step iterations
|
||||||
* that stack rewards with t_max MiniTrans
|
* that stack rewards with t_max MiniTrans
|
||||||
|
@ -65,13 +90,11 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
synchronized (getAsyncGlobal()) {
|
synchronized (getAsyncGlobal()) {
|
||||||
current.copy(getAsyncGlobal().getCurrent());
|
current.copy(getAsyncGlobal().getCurrent());
|
||||||
}
|
}
|
||||||
Stack<MiniTrans<Integer>> rewards = new Stack<>();
|
|
||||||
|
|
||||||
Observation obs = sObs;
|
Observation obs = sObs;
|
||||||
IPolicy<O, Integer> policy = getPolicy(current);
|
IPolicy<O, Integer> policy = getPolicy(current);
|
||||||
|
|
||||||
Integer action;
|
Integer action = getMdp().getActionSpace().noOp();
|
||||||
Integer lastAction = getMdp().getActionSpace().noOp();
|
|
||||||
IHistoryProcessor hp = getHistoryProcessor();
|
IHistoryProcessor hp = getHistoryProcessor();
|
||||||
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
||||||
|
|
||||||
|
@ -82,21 +105,15 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||||
|
|
||||||
//if step of training, just repeat lastAction
|
//if step of training, just repeat lastAction
|
||||||
if (obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
action = lastAction;
|
|
||||||
} else {
|
|
||||||
action = policy.nextAction(obs);
|
action = policy.nextAction(obs);
|
||||||
}
|
}
|
||||||
|
|
||||||
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
|
StepReply<Observation> stepReply = getLegacyMDPWrapper().step(action);
|
||||||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
||||||
|
|
||||||
//if it's not a skipped frame, you can do a step of training
|
|
||||||
if (!obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
|
experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
|
||||||
INDArray[] output = current.outputAll(obs.getData());
|
|
||||||
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
|
||||||
|
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,29 +121,14 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
reward += stepReply.getReward();
|
reward += stepReply.getReward();
|
||||||
|
|
||||||
incrementStep();
|
incrementStep();
|
||||||
lastAction = action;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//a bit of a trick usable because of how the stack is treated to init R
|
if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||||
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
|
experienceHandler.setFinalObservation(obs);
|
||||||
|
|
||||||
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
|
|
||||||
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
|
||||||
else {
|
|
||||||
INDArray[] output = null;
|
|
||||||
if (getConf().getLearnerUpdateFrequency() == -1)
|
|
||||||
output = current.outputAll(obs.getData());
|
|
||||||
else synchronized (getAsyncGlobal()) {
|
|
||||||
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
|
|
||||||
}
|
|
||||||
double maxQ = Nd4j.max(output[0]).getDouble(0);
|
|
||||||
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
|
getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
|
||||||
|
|
||||||
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
|
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -13,28 +13,14 @@
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import lombok.Value;
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
|
||||||
/**
|
import java.util.List;
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
|
||||||
*
|
public interface UpdateAlgorithm<NN extends NeuralNet> {
|
||||||
* Its called a MiniTrans because it is similar to a Transition
|
Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience);
|
||||||
* but without a next observation
|
|
||||||
*
|
|
||||||
* It is stacked and then processed by AsyncNStepQL or A3C
|
|
||||||
* following the paper implementation https://arxiv.org/abs/1602.01783 paper.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Value
|
|
||||||
public class MiniTrans<A> {
|
|
||||||
INDArray obs;
|
|
||||||
A action;
|
|
||||||
INDArray[] output;
|
|
||||||
double reward;
|
|
||||||
}
|
}
|
|
@ -18,11 +18,7 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.rl4j.learning.async.*;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -34,9 +30,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
|
@ -67,6 +61,8 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
if(seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setUpdateAlgorithm(buildUpdateAlgorithm());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -74,52 +70,9 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
return new ACPolicy(net, rnd);
|
return new ACPolicy(net, rnd);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* calc the gradients based on the n-step rewards
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public Gradient[] calcGradient(IActorCritic iac, Stack<MiniTrans<Integer>> rewards) {
|
protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
|
||||||
MiniTrans<Integer> minTrans = rewards.pop();
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||||
|
return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
|
||||||
int size = rewards.size();
|
|
||||||
|
|
||||||
//if recurrent then train as a time serie with a batch size of 1
|
|
||||||
boolean recurrent = getAsyncGlobal().getCurrent().isRecurrent();
|
|
||||||
|
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
|
||||||
: getHistoryProcessor().getConf().getShape();
|
|
||||||
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
|
||||||
: Learning.makeShape(size, shape);
|
|
||||||
|
|
||||||
INDArray input = Nd4j.create(nshape);
|
|
||||||
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
|
||||||
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, getMdp().getActionSpace().getSize(), size)
|
|
||||||
: Nd4j.zeros(size, getMdp().getActionSpace().getSize());
|
|
||||||
|
|
||||||
double r = minTrans.getReward();
|
|
||||||
for (int i = size - 1; i >= 0; i--) {
|
|
||||||
minTrans = rewards.pop();
|
|
||||||
|
|
||||||
r = minTrans.getReward() + conf.getGamma() * r;
|
|
||||||
if (recurrent) {
|
|
||||||
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(minTrans.getObs());
|
|
||||||
} else {
|
|
||||||
input.putRow(i, minTrans.getObs());
|
|
||||||
}
|
|
||||||
|
|
||||||
//the critic
|
|
||||||
targets.putScalar(i, r);
|
|
||||||
|
|
||||||
//the actor
|
|
||||||
double expectedV = minTrans.getOutput()[0].getDouble(0);
|
|
||||||
double advantage = r - expectedV;
|
|
||||||
if (recurrent) {
|
|
||||||
logSoftmax.putScalar(0, minTrans.getAction(), i, advantage);
|
|
||||||
} else {
|
|
||||||
logSoftmax.putScalar(i, minTrans.getAction(), advantage);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return iac.gradient(input, new INDArray[] {targets, logSoftmax});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
|
||||||
|
|
||||||
|
private final IAsyncGlobal asyncGlobal;
|
||||||
|
private final int[] shape;
|
||||||
|
private final int actionSpaceSize;
|
||||||
|
private final int targetDqnUpdateFreq;
|
||||||
|
private final double gamma;
|
||||||
|
private final boolean recurrent;
|
||||||
|
|
||||||
|
public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
|
||||||
|
int[] shape,
|
||||||
|
int actionSpaceSize,
|
||||||
|
int targetDqnUpdateFreq,
|
||||||
|
double gamma) {
|
||||||
|
|
||||||
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
|
||||||
|
//if recurrent then train as a time serie with a batch size of 1
|
||||||
|
recurrent = asyncGlobal.getCurrent().isRecurrent();
|
||||||
|
this.shape = shape;
|
||||||
|
this.actionSpaceSize = actionSpaceSize;
|
||||||
|
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
|
||||||
|
this.gamma = gamma;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] computeGradients(IActorCritic current, List<StateActionPair<Integer>> experience) {
|
||||||
|
int size = experience.size();
|
||||||
|
|
||||||
|
int[] nshape = recurrent ? Learning.makeShape(1, shape, size)
|
||||||
|
: Learning.makeShape(size, shape);
|
||||||
|
|
||||||
|
INDArray input = Nd4j.create(nshape);
|
||||||
|
INDArray targets = recurrent ? Nd4j.create(1, 1, size) : Nd4j.create(size, 1);
|
||||||
|
INDArray logSoftmax = recurrent ? Nd4j.zeros(1, actionSpaceSize, size)
|
||||||
|
: Nd4j.zeros(size, actionSpaceSize);
|
||||||
|
|
||||||
|
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||||
|
double r;
|
||||||
|
if(stateActionPair.isTerminal()) {
|
||||||
|
r = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
INDArray[] output = null;
|
||||||
|
if (targetDqnUpdateFreq == -1)
|
||||||
|
output = current.outputAll(stateActionPair.getObservation().getData());
|
||||||
|
else synchronized (asyncGlobal) {
|
||||||
|
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
|
||||||
|
}
|
||||||
|
r = output[0].getDouble(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = size - 1; i >= 0; --i) {
|
||||||
|
stateActionPair = experience.get(i);
|
||||||
|
|
||||||
|
INDArray observationData = stateActionPair.getObservation().getData();
|
||||||
|
|
||||||
|
INDArray[] output = current.outputAll(observationData);
|
||||||
|
|
||||||
|
r = stateActionPair.getReward() + gamma * r;
|
||||||
|
if (recurrent) {
|
||||||
|
input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
|
||||||
|
} else {
|
||||||
|
input.putRow(i, observationData);
|
||||||
|
}
|
||||||
|
|
||||||
|
//the critic
|
||||||
|
targets.putScalar(i, r);
|
||||||
|
|
||||||
|
//the actor
|
||||||
|
double expectedV = output[0].getDouble(0);
|
||||||
|
double advantage = r - expectedV;
|
||||||
|
if (recurrent) {
|
||||||
|
logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
|
||||||
|
} else {
|
||||||
|
logSoftmax.putScalar(i, stateActionPair.getAction(), advantage);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// targets -> value, critic
|
||||||
|
// logSoftmax -> policy, actor
|
||||||
|
return current.gradient(input, new INDArray[] {targets, logSoftmax});
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,11 +18,9 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -32,12 +30,9 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*/
|
*/
|
||||||
|
@ -65,6 +60,8 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
if(seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setUpdateAlgorithm(buildUpdateAlgorithm());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Policy<O, Integer> getPolicy(IDQN nn) {
|
public Policy<O, Integer> getPolicy(IDQN nn) {
|
||||||
|
@ -72,32 +69,9 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
rnd, conf.getMinEpsilon(), this);
|
rnd, conf.getMinEpsilon(), this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
|
||||||
//calc the gradient based on the n-step rewards
|
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
|
||||||
public Gradient[] calcGradient(IDQN current, Stack<MiniTrans<Integer>> rewards) {
|
return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
|
||||||
|
|
||||||
MiniTrans<Integer> minTrans = rewards.pop();
|
|
||||||
|
|
||||||
int size = rewards.size();
|
|
||||||
|
|
||||||
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
|
|
||||||
: getHistoryProcessor().getConf().getShape();
|
|
||||||
int[] nshape = Learning.makeShape(size, shape);
|
|
||||||
INDArray input = Nd4j.create(nshape);
|
|
||||||
INDArray targets = Nd4j.create(size, getMdp().getActionSpace().getSize());
|
|
||||||
|
|
||||||
double r = minTrans.getReward();
|
|
||||||
for (int i = size - 1; i >= 0; i--) {
|
|
||||||
minTrans = rewards.pop();
|
|
||||||
|
|
||||||
r = minTrans.getReward() + conf.getGamma() * r;
|
|
||||||
input.putRow(i, minTrans.getObs());
|
|
||||||
INDArray row = minTrans.getOutput()[0];
|
|
||||||
row = row.putScalar(minTrans.getAction(), r);
|
|
||||||
targets.putRow(i, row);
|
|
||||||
}
|
|
||||||
|
|
||||||
return current.gradient(input, targets);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
|
||||||
|
|
||||||
|
private final IAsyncGlobal asyncGlobal;
|
||||||
|
private final int[] shape;
|
||||||
|
private final int actionSpaceSize;
|
||||||
|
private final int targetDqnUpdateFreq;
|
||||||
|
private final double gamma;
|
||||||
|
|
||||||
|
public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
|
||||||
|
int[] shape,
|
||||||
|
int actionSpaceSize,
|
||||||
|
int targetDqnUpdateFreq,
|
||||||
|
double gamma) {
|
||||||
|
|
||||||
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
this.shape = shape;
|
||||||
|
this.actionSpaceSize = actionSpaceSize;
|
||||||
|
this.targetDqnUpdateFreq = targetDqnUpdateFreq;
|
||||||
|
this.gamma = gamma;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] computeGradients(IDQN current, List<StateActionPair<Integer>> experience) {
|
||||||
|
int size = experience.size();
|
||||||
|
|
||||||
|
int[] nshape = Learning.makeShape(size, shape);
|
||||||
|
INDArray input = Nd4j.create(nshape);
|
||||||
|
INDArray targets = Nd4j.create(size, actionSpaceSize);
|
||||||
|
|
||||||
|
StateActionPair<Integer> stateActionPair = experience.get(size - 1);
|
||||||
|
|
||||||
|
double r;
|
||||||
|
if(stateActionPair.isTerminal()) {
|
||||||
|
r = 0;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
INDArray[] output = null;
|
||||||
|
if (targetDqnUpdateFreq == -1)
|
||||||
|
output = current.outputAll(stateActionPair.getObservation().getData());
|
||||||
|
else synchronized (asyncGlobal) {
|
||||||
|
output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
|
||||||
|
}
|
||||||
|
r = Nd4j.max(output[0]).getDouble(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = size - 1; i >= 0; i--) {
|
||||||
|
stateActionPair = experience.get(i);
|
||||||
|
|
||||||
|
input.putRow(i, stateActionPair.getObservation().getData());
|
||||||
|
|
||||||
|
r = stateActionPair.getReward() + gamma * r;
|
||||||
|
INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
|
||||||
|
INDArray row = output[0];
|
||||||
|
row = row.putScalar(stateActionPair.getAction(), r);
|
||||||
|
targets.putRow(i, row);
|
||||||
|
}
|
||||||
|
|
||||||
|
return current.gradient(input, targets);
|
||||||
|
}
|
||||||
|
}
|
|
@ -80,6 +80,9 @@ public class ExpReplay<A> implements IExpReplay<A> {
|
||||||
//log.info("size: "+storage.size());
|
//log.info("size: "+storage.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int getBatchSize() {
|
||||||
|
int storageSize = storage.size();
|
||||||
|
return Math.min(storageSize, batchSize);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,6 +32,11 @@ import java.util.ArrayList;
|
||||||
*/
|
*/
|
||||||
public interface IExpReplay<A> {
|
public interface IExpReplay<A> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The size of the batch that will be returned by getBatch()
|
||||||
|
*/
|
||||||
|
int getBatchSize();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return a batch of uniformly sampled transitions
|
* @return a batch of uniformly sampled transitions
|
||||||
*/
|
*/
|
||||||
|
@ -42,5 +47,4 @@ public interface IExpReplay<A> {
|
||||||
* @param transition a new transition to store
|
* @param transition a new transition to store
|
||||||
*/
|
*/
|
||||||
void store(Transition<A> transition);
|
void store(Transition<A> transition);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,32 +60,8 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
extends SyncLearning<O, A, AS, IDQN>
|
extends SyncLearning<O, A, AS, IDQN>
|
||||||
implements TargetQNetworkSource, EpochStepCounter {
|
implements TargetQNetworkSource, EpochStepCounter {
|
||||||
|
|
||||||
// FIXME Changed for refac
|
|
||||||
// @Getter
|
|
||||||
// final private IExpReplay<A> expReplay;
|
|
||||||
@Getter
|
|
||||||
@Setter(AccessLevel.PROTECTED)
|
|
||||||
protected IExpReplay<A> expReplay;
|
|
||||||
|
|
||||||
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||||
|
|
||||||
public QLearning(QLearningConfiguration conf) {
|
|
||||||
this(conf, getSeededRandom(conf.getSeed()));
|
|
||||||
}
|
|
||||||
|
|
||||||
public QLearning(QLearningConfiguration conf, Random random) {
|
|
||||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Random getSeededRandom(Long seed) {
|
|
||||||
Random rnd = Nd4j.getRandom();
|
|
||||||
if(seed != null) {
|
|
||||||
rnd.setSeed(seed);
|
|
||||||
}
|
|
||||||
|
|
||||||
return rnd;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
protected abstract EpsGreedy<O, A, AS> getEgPolicy();
|
||||||
|
|
||||||
public abstract MDP<O, A, AS> getMdp();
|
public abstract MDP<O, A, AS> getMdp();
|
||||||
|
|
|
@ -21,6 +21,8 @@ import lombok.AccessLevel;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
|
@ -42,7 +44,7 @@ import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -71,10 +73,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
private int lastAction;
|
private int lastAction;
|
||||||
private double accuReward = 0;
|
private double accuReward = 0;
|
||||||
|
|
||||||
private Transition pendingTransition;
|
|
||||||
|
|
||||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||||
|
|
||||||
|
// TODO: User a builder and remove the setter
|
||||||
|
@Getter(AccessLevel.PROTECTED) @Setter
|
||||||
|
private ExperienceHandler<Integer, Transition<Integer>> experienceHandler;
|
||||||
|
|
||||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||||
return mdp;
|
return mdp;
|
||||||
}
|
}
|
||||||
|
@ -85,7 +89,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
||||||
int epsilonNbStep, Random random) {
|
int epsilonNbStep, Random random) {
|
||||||
super(conf);
|
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
|
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
|
@ -98,6 +101,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
|
||||||
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
|
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
|
||||||
|
|
||||||
|
experienceHandler = new ReplayMemoryExperienceHandler(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||||
}
|
}
|
||||||
|
|
||||||
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
public MDP<O, Integer, DiscreteSpace> getMdp() {
|
||||||
|
@ -114,7 +118,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
public void preEpoch() {
|
public void preEpoch() {
|
||||||
lastAction = mdp.getActionSpace().noOp();
|
lastAction = mdp.getActionSpace().noOp();
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
pendingTransition = null;
|
experienceHandler.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -131,8 +135,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
*/
|
*/
|
||||||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||||
|
|
||||||
Integer action;
|
|
||||||
|
|
||||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||||
|
@ -142,37 +144,28 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||||
|
|
||||||
//if step of training, just repeat lastAction
|
//if step of training, just repeat lastAction
|
||||||
if (obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
action = lastAction;
|
|
||||||
} else {
|
|
||||||
INDArray qs = getQNetwork().output(obs);
|
INDArray qs = getQNetwork().output(obs);
|
||||||
int maxAction = Learning.getMaxAction(qs);
|
int maxAction = Learning.getMaxAction(qs);
|
||||||
maxQ = qs.getDouble(maxAction);
|
maxQ = qs.getDouble(maxAction);
|
||||||
|
|
||||||
action = getEgPolicy().nextAction(obs);
|
lastAction = getEgPolicy().nextAction(obs);
|
||||||
}
|
}
|
||||||
|
|
||||||
lastAction = action;
|
StepReply<Observation> stepReply = mdp.step(lastAction);
|
||||||
|
|
||||||
StepReply<Observation> stepReply = mdp.step(action);
|
|
||||||
|
|
||||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||||
|
|
||||||
//if it's not a skipped frame, you can do a step of training
|
//if it's not a skipped frame, you can do a step of training
|
||||||
if (!obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
|
|
||||||
// Add experience
|
// Add experience
|
||||||
if (pendingTransition != null) {
|
experienceHandler.addExperience(obs, lastAction, accuReward, stepReply.isDone());
|
||||||
pendingTransition.setNextObservation(obs);
|
|
||||||
getExpReplay().store(pendingTransition);
|
|
||||||
}
|
|
||||||
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
|
|
||||||
accuReward = 0;
|
accuReward = 0;
|
||||||
|
|
||||||
// Update NN
|
// Update NN
|
||||||
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
||||||
if (getStepCounter() > updateStart) {
|
if (getStepCounter() > updateStart) {
|
||||||
DataSet targets = setTarget(getExpReplay().getBatch());
|
DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
|
||||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -180,7 +173,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
||||||
if (transitions.size() == 0)
|
if (transitions.size() == 0)
|
||||||
throw new IllegalArgumentException("too few transitions");
|
throw new IllegalArgumentException("too few transitions");
|
||||||
|
|
||||||
|
@ -189,9 +182,6 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finishEpoch(Observation observation) {
|
protected void finishEpoch(Observation observation) {
|
||||||
if (pendingTransition != null) {
|
experienceHandler.setFinalObservation(observation);
|
||||||
pendingTransition.setNextObservation(observation);
|
|
||||||
getExpReplay().store(pendingTransition);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class ReplayMemoryExperienceHandlerTest {
|
||||||
|
@Test
|
||||||
|
public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
|
||||||
|
// Arrange
|
||||||
|
TestExpReplay expReplayMock = new TestExpReplay();
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
int numStoredTransitions = expReplayMock.addedTransitions.size();
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0, numStoredTransitions);
|
||||||
|
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addingExperience_expect_transitionsAreCorrect() {
|
||||||
|
// Arrange
|
||||||
|
TestExpReplay expReplayMock = new TestExpReplay();
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, expReplayMock.addedTransitions.size());
|
||||||
|
|
||||||
|
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
|
||||||
|
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||||
|
assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
|
||||||
|
assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
|
||||||
|
|
||||||
|
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
|
||||||
|
assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
|
||||||
|
assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
|
||||||
|
assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
|
||||||
|
// Arrange
|
||||||
|
TestExpReplay expReplayMock = new TestExpReplay();
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 })));
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, expReplayMock.addedTransitions.size());
|
||||||
|
assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||||
|
// Arrange
|
||||||
|
TestExpReplay expReplayMock = new TestExpReplay();
|
||||||
|
ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
|
||||||
|
sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
|
||||||
|
sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
int size = sut.getTrainingBatchSize();
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestExpReplay implements IExpReplay<Integer> {
|
||||||
|
|
||||||
|
public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ArrayList<Transition<Integer>> getBatch() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void store(Transition<Integer> transition) {
|
||||||
|
addedTransitions.add(transition);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getBatchSize() {
|
||||||
|
return addedTransitions.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
package org.deeplearning4j.rl4j.experience;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class StateActionExperienceHandlerTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addingExperience_expect_generateTrainingBatchReturnsIt() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||||
|
sut.reset();
|
||||||
|
Observation observation = new Observation(Nd4j.zeros(1));
|
||||||
|
sut.addExperience(observation, 123, 234.0, true);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, result.size());
|
||||||
|
assertSame(observation, result.get(0).getObservation());
|
||||||
|
assertEquals(123, (int)result.get(0).getAction());
|
||||||
|
assertEquals(234.0, result.get(0).getReward(), 0.00001);
|
||||||
|
assertTrue(result.get(0).isTerminal());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addingMultipleExperiences_expect_generateTrainingBatchReturnsItInSameOrder() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||||
|
sut.reset();
|
||||||
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
sut.addExperience(null, 3, 3.0, false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
List<StateActionPair<Integer>> result = sut.generateTrainingBatch();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, result.size());
|
||||||
|
assertEquals(1, (int)result.get(0).getAction());
|
||||||
|
assertEquals(2, (int)result.get(1).getAction());
|
||||||
|
assertEquals(3, (int)result.get(2).getAction());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_gettingExperience_expect_experienceStoreIsCleared() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||||
|
sut.reset();
|
||||||
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
List<StateActionPair<Integer>> firstResult = sut.generateTrainingBatch();
|
||||||
|
List<StateActionPair<Integer>> secondResult = sut.generateTrainingBatch();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, firstResult.size());
|
||||||
|
assertEquals(0, secondResult.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
|
||||||
|
// Arrange
|
||||||
|
StateActionExperienceHandler sut = new StateActionExperienceHandler();
|
||||||
|
sut.reset();
|
||||||
|
sut.addExperience(null, 1, 1.0, false);
|
||||||
|
sut.addExperience(null, 2, 2.0, false);
|
||||||
|
sut.addExperience(null, 3, 3.0, false);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
int size = sut.getTrainingBatchSize();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, size);
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,9 +18,12 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
|
@ -31,7 +34,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@ -51,7 +53,9 @@ public class AsyncThreadDiscreteTest {
|
||||||
TrainingListenerList listeners = new TrainingListenerList();
|
TrainingListenerList listeners = new TrainingListenerList();
|
||||||
MockPolicy policyMock = new MockPolicy();
|
MockPolicy policyMock = new MockPolicy();
|
||||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
|
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
|
||||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
|
||||||
|
MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
|
||||||
|
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
|
||||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
|
@ -102,62 +106,22 @@ public class AsyncThreadDiscreteTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NeuralNetwork
|
// ExperienceHandler
|
||||||
assertEquals(2, nnMock.copyCallCount);
|
double[][] expectedExperienceHandlerInputs = new double[][] {
|
||||||
double[][] expectedNNInputs = new double[][] {
|
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
|
||||||
};
|
};
|
||||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
|
||||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
|
||||||
double[] expectedRow = expectedNNInputs[i];
|
double[] expectedRow = expectedExperienceHandlerInputs[i];
|
||||||
INDArray input = nnMock.outputAllInputs.get(i);
|
INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
|
||||||
assertEquals(expectedRow.length, input.shape()[1]);
|
assertEquals(expectedRow.length, input.shape()[1]);
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
for(int j = 0; j < expectedRow.length; ++j) {
|
||||||
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int arrayIdx = 0;
|
|
||||||
double[][][] expectedMinitransObs = new double[][][] {
|
|
||||||
new double[][] {
|
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
|
|
||||||
},
|
|
||||||
new double[][] {
|
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
|
||||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
|
||||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
|
|
||||||
}
|
|
||||||
};
|
|
||||||
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
|
||||||
double[] expectedRewards = new double[] { 0.0, 0.0, 3.0, 0.0, 0.0, 6.0 };
|
|
||||||
|
|
||||||
assertEquals(2, sut.rewards.size());
|
|
||||||
for(int rewardIdx = 0; rewardIdx < 2; ++rewardIdx) {
|
|
||||||
Stack<MiniTrans<Integer>> miniTransStack = sut.rewards.get(rewardIdx);
|
|
||||||
|
|
||||||
for (int i = 0; i < expectedMinitransObs[rewardIdx].length; ++i) {
|
|
||||||
MiniTrans minitrans = miniTransStack.get(i);
|
|
||||||
|
|
||||||
// Observation
|
|
||||||
double[] expectedRow = expectedMinitransObs[rewardIdx][i];
|
|
||||||
INDArray realRewards = minitrans.getObs();
|
|
||||||
assertEquals(expectedRow.length, realRewards.shape()[1]);
|
|
||||||
for (int j = 0; j < expectedRow.length; ++j) {
|
|
||||||
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * realRewards.getDouble(j), 0.00001);
|
|
||||||
}
|
|
||||||
|
|
||||||
assertEquals(expectedOutputs[arrayIdx], minitrans.getOutput()[0].getDouble(0), 0.00001);
|
|
||||||
assertEquals(expectedRewards[arrayIdx], minitrans.getReward(), 0.00001);
|
|
||||||
++arrayIdx;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
|
||||||
|
@ -167,22 +131,19 @@ public class AsyncThreadDiscreteTest {
|
||||||
private final MockAsyncConfiguration config;
|
private final MockAsyncConfiguration config;
|
||||||
|
|
||||||
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
|
public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
|
||||||
public final List<Stack<MiniTrans<Integer>>> rewards = new ArrayList<Stack<MiniTrans<Integer>>>();
|
|
||||||
|
|
||||||
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
|
||||||
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
|
||||||
MockAsyncConfiguration config, IHistoryProcessor hp) {
|
MockAsyncConfiguration config, IHistoryProcessor hp,
|
||||||
|
ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
|
||||||
|
UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.policy = policy;
|
this.policy = policy;
|
||||||
this.config = config;
|
this.config = config;
|
||||||
setHistoryProcessor(hp);
|
setHistoryProcessor(hp);
|
||||||
}
|
setExperienceHandler(experienceHandler);
|
||||||
|
setUpdateAlgorithm(updateAlgorithm);
|
||||||
@Override
|
|
||||||
public Gradient[] calcGradient(MockNeuralNet mockNeuralNet, Stack<MiniTrans<Integer>> rewards) {
|
|
||||||
this.rewards.add(rewards);
|
|
||||||
return new Gradient[0];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -200,6 +161,11 @@ public class AsyncThreadDiscreteTest {
|
||||||
return policy;
|
return policy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
|
public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
|
||||||
asyncGlobal.increaseCurrentLoop();
|
asyncGlobal.increaseCurrentLoop();
|
||||||
|
|
|
@ -1,197 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
|
||||||
* Copyright (c) 2020 Konduit K.K.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.OutputStream;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
public class A3CThreadDiscreteTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void refac_calcGradient() {
|
|
||||||
// Arrange
|
|
||||||
double gamma = 0.9;
|
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
|
||||||
A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
|
|
||||||
MockActorCritic actorCriticMock = new MockActorCritic();
|
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
|
||||||
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
|
||||||
A3CThreadDiscrete sut = new A3CThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, 0, null, 0);
|
|
||||||
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
|
|
||||||
sut.setHistoryProcessor(hpMock);
|
|
||||||
|
|
||||||
double[][] minitransObs = new double[][] {
|
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
|
||||||
};
|
|
||||||
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
|
|
||||||
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
|
|
||||||
|
|
||||||
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
|
|
||||||
for(int i = 0; i < 3; ++i) {
|
|
||||||
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
|
|
||||||
INDArray[] output = new INDArray[] {
|
|
||||||
Nd4j.zeros(5)
|
|
||||||
};
|
|
||||||
output[0].putScalar(i, outputs[i]);
|
|
||||||
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
|
||||||
}
|
|
||||||
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
|
||||||
|
|
||||||
// Act
|
|
||||||
sut.calcGradient(actorCriticMock, minitransList);
|
|
||||||
|
|
||||||
// Assert
|
|
||||||
assertEquals(1, actorCriticMock.gradientParams.size());
|
|
||||||
INDArray input = actorCriticMock.gradientParams.get(0).getFirst();
|
|
||||||
INDArray[] labels = actorCriticMock.gradientParams.get(0).getSecond();
|
|
||||||
|
|
||||||
assertEquals(minitransObs.length, input.shape()[0]);
|
|
||||||
for(int i = 0; i < minitransObs.length; ++i) {
|
|
||||||
double[] expectedRow = minitransObs[i];
|
|
||||||
assertEquals(expectedRow.length, input.shape()[1]);
|
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
|
||||||
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
double latestReward = (gamma * 4.0) + 3.0;
|
|
||||||
double[] expectedLabels0 = new double[] { gamma * gamma * latestReward, gamma * latestReward, latestReward };
|
|
||||||
for(int i = 0; i < expectedLabels0.length; ++i) {
|
|
||||||
assertEquals(expectedLabels0[i], labels[0].getDouble(i), 0.00001);
|
|
||||||
}
|
|
||||||
double[][] expectedLabels1 = new double[][] {
|
|
||||||
new double[] { 4.346, 0.0, 0.0, 0.0, 0.0 },
|
|
||||||
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
|
|
||||||
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
|
|
||||||
};
|
|
||||||
|
|
||||||
assertArrayEquals(new long[] { expectedLabels0.length, 1 }, labels[0].shape());
|
|
||||||
|
|
||||||
for(int i = 0; i < expectedLabels1.length; ++i) {
|
|
||||||
double[] expectedRow = expectedLabels1[i];
|
|
||||||
assertEquals(expectedRow.length, labels[1].shape()[1]);
|
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
|
||||||
assertEquals(expectedRow[j], labels[1].getDouble(i, j), 0.00001);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public class MockActorCritic implements IActorCritic {
|
|
||||||
|
|
||||||
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public NeuralNetwork[] getNeuralNetworks() {
|
|
||||||
return new NeuralNetwork[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isRecurrent() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void fit(INDArray input, INDArray[] labels) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
|
||||||
return new INDArray[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public IActorCritic clone() {
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(NeuralNet from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(IActorCritic from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
|
||||||
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
|
|
||||||
return new Gradient[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void applyGradient(Gradient[] gradient, int batchSize) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(String pathValue, String pathPolicy) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public double getLatestScore() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(OutputStream os) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void save(String filename) throws IOException {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class A3CUpdateAlgorithmTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void refac_calcGradient_non_terminal() {
|
||||||
|
// Arrange
|
||||||
|
double gamma = 0.9;
|
||||||
|
MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
|
||||||
|
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||||
|
MockActorCritic actorCriticMock = new MockActorCritic();
|
||||||
|
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
||||||
|
A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
|
||||||
|
|
||||||
|
|
||||||
|
INDArray[] originalObservations = new INDArray[] {
|
||||||
|
Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
|
||||||
|
Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
|
||||||
|
Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
|
||||||
|
Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
|
||||||
|
};
|
||||||
|
int[] actions = new int[] { 0, 1, 2, 1 };
|
||||||
|
double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
|
||||||
|
|
||||||
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
|
||||||
|
for(int i = 0; i < originalObservations.length; ++i) {
|
||||||
|
experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.computeGradients(actorCriticMock, experience);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, actorCriticMock.gradientParams.size());
|
||||||
|
|
||||||
|
// Inputs
|
||||||
|
INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
|
||||||
|
for(int i = 0; i < 4; ++i) {
|
||||||
|
for(int j = 0; j < 5; ++j) {
|
||||||
|
assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
|
||||||
|
INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
|
||||||
|
|
||||||
|
assertEquals(4, targets.shape()[0]);
|
||||||
|
assertEquals(1, targets.shape()[1]);
|
||||||
|
|
||||||
|
// FIXME: check targets values once fixed
|
||||||
|
|
||||||
|
assertEquals(4, logSoftmax.shape()[0]);
|
||||||
|
assertEquals(5, logSoftmax.shape()[1]);
|
||||||
|
|
||||||
|
// FIXME: check logSoftmax values once fixed
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class MockActorCritic implements IActorCritic {
|
||||||
|
|
||||||
|
public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNetwork[] getNeuralNetworks() {
|
||||||
|
return new NeuralNetwork[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isRecurrent() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(INDArray input, INDArray[] labels) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
|
return new INDArray[] { batch.mul(-1.0) };
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IActorCritic clone() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(NeuralNet from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(IActorCritic from) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||||
|
gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void applyGradient(Gradient[] gradient, int batchSize) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String pathValue, String pathPolicy) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public double getLatestScore() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(OutputStream os) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void save(String filename) throws IOException {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,98 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2020 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.util.Stack;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
public class AsyncNStepQLearningThreadDiscreteTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void refac_calcGradient() {
|
|
||||||
// Arrange
|
|
||||||
double gamma = 0.9;
|
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
|
||||||
AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
|
|
||||||
MockDQN dqnMock = new MockDQN();
|
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
|
||||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
|
||||||
AsyncNStepQLearningThreadDiscrete sut = new AsyncNStepQLearningThreadDiscrete<MockEncodable>(mdpMock, asyncGlobalMock, config, null, 0, 0);
|
|
||||||
MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
|
|
||||||
sut.setHistoryProcessor(hpMock);
|
|
||||||
|
|
||||||
double[][] minitransObs = new double[][] {
|
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
|
||||||
};
|
|
||||||
double[] outputs = new double[] { 1.0, 2.0, 3.0 };
|
|
||||||
double[] rewards = new double[] { 0.0, 0.0, 3.0 };
|
|
||||||
|
|
||||||
Stack<MiniTrans<Integer>> minitransList = new Stack<MiniTrans<Integer>>();
|
|
||||||
for(int i = 0; i < 3; ++i) {
|
|
||||||
INDArray obs = Nd4j.create(minitransObs[i]).reshape(5, 1, 1);
|
|
||||||
INDArray[] output = new INDArray[] {
|
|
||||||
Nd4j.zeros(5)
|
|
||||||
};
|
|
||||||
output[0].putScalar(i, outputs[i]);
|
|
||||||
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
|
||||||
}
|
|
||||||
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
|
||||||
|
|
||||||
// Act
|
|
||||||
sut.calcGradient(dqnMock, minitransList);
|
|
||||||
|
|
||||||
// Assert
|
|
||||||
assertEquals(1, dqnMock.gradientParams.size());
|
|
||||||
INDArray input = dqnMock.gradientParams.get(0).getFirst();
|
|
||||||
INDArray labels = dqnMock.gradientParams.get(0).getSecond();
|
|
||||||
|
|
||||||
assertEquals(minitransObs.length, input.shape()[0]);
|
|
||||||
for(int i = 0; i < minitransObs.length; ++i) {
|
|
||||||
double[] expectedRow = minitransObs[i];
|
|
||||||
assertEquals(expectedRow.length, input.shape()[1]);
|
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
|
||||||
assertEquals(expectedRow[j], input.getDouble(i, j, 1, 1), 0.00001);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
double latestReward = (gamma * 4.0) + 3.0;
|
|
||||||
double[][] expectedLabels = new double[][] {
|
|
||||||
new double[] { gamma * gamma * latestReward, 0.0, 0.0, 0.0, 0.0 },
|
|
||||||
new double[] { 0.0, gamma * latestReward, 0.0, 0.0, 0.0 },
|
|
||||||
new double[] { 0.0, 0.0, latestReward, 0.0, 0.0 },
|
|
||||||
};
|
|
||||||
assertEquals(minitransObs.length, labels.shape()[0]);
|
|
||||||
for(int i = 0; i < minitransObs.length; ++i) {
|
|
||||||
double[] expectedRow = expectedLabels[i];
|
|
||||||
assertEquals(expectedRow.length, labels.shape()[1]);
|
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
|
||||||
assertEquals(expectedRow[j], labels.getDouble(i, j), 0.00001);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockDQN;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class QLearningUpdateAlgorithmTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_isTerminal_expect_initRewardIs0() {
|
||||||
|
// Arrange
|
||||||
|
MockDQN dqnMock = new MockDQN();
|
||||||
|
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
||||||
|
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
|
||||||
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
|
{
|
||||||
|
add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
|
||||||
|
// Arrange
|
||||||
|
MockDQN globalDQNMock = new MockDQN();
|
||||||
|
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||||
|
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
|
||||||
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
|
{
|
||||||
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MockDQN dqnMock = new MockDQN();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(2, dqnMock.outputAllParams.size());
|
||||||
|
assertEquals(-123.0, dqnMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
|
||||||
|
// Arrange
|
||||||
|
MockDQN globalDQNMock = new MockDQN();
|
||||||
|
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||||
|
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
|
||||||
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
|
{
|
||||||
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MockDQN dqnMock = new MockDQN();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(1, globalDQNMock.outputAllParams.size());
|
||||||
|
assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
|
||||||
|
// Arrange
|
||||||
|
double gamma = 0.9;
|
||||||
|
MockDQN globalDQNMock = new MockDQN();
|
||||||
|
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
|
||||||
|
UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
|
||||||
|
List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
|
||||||
|
{
|
||||||
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
|
||||||
|
add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -2.1, -2.2 })), 1, 2.0, true));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
MockDQN dqnMock = new MockDQN();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.computeGradients(dqnMock, experience);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// input side -- should be a stack of observations
|
||||||
|
INDArray input = dqnMock.gradientParams.get(0).getLeft();
|
||||||
|
assertEquals(-1.1, input.getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(-1.2, input.getDouble(0, 1), 0.00001);
|
||||||
|
assertEquals(-2.1, input.getDouble(1, 0), 0.00001);
|
||||||
|
assertEquals(-2.2, input.getDouble(1, 1), 0.00001);
|
||||||
|
|
||||||
|
// target side
|
||||||
|
INDArray target = dqnMock.gradientParams.get(0).getRight();
|
||||||
|
assertEquals(1.0 + gamma * 2.0, target.getDouble(0, 0), 0.00001);
|
||||||
|
assertEquals(1.2, target.getDouble(0, 1), 0.00001);
|
||||||
|
assertEquals(2.1, target.getDouble(1, 0), 0.00001);
|
||||||
|
assertEquals(2.0, target.getDouble(1, 1), 0.00001);
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,6 +17,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
|
@ -75,8 +77,8 @@ public class QLearningDiscreteTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
MockExperienceHandler experienceHandler = new MockExperienceHandler();
|
||||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||||
sut.setHistoryProcessor(hp);
|
sut.setHistoryProcessor(hp);
|
||||||
|
@ -93,7 +95,6 @@ public class QLearningDiscreteTest {
|
||||||
for (int i = 0; i < expectedRecords.length; ++i) {
|
for (int i = 0; i < expectedRecords.length; ++i) {
|
||||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEquals(0, hp.startMonitorCallCount);
|
assertEquals(0, hp.startMonitorCallCount);
|
||||||
assertEquals(0, hp.stopMonitorCallCount);
|
assertEquals(0, hp.stopMonitorCallCount);
|
||||||
|
|
||||||
|
@ -133,7 +134,7 @@ public class QLearningDiscreteTest {
|
||||||
// MDP calls
|
// MDP calls
|
||||||
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
|
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
|
||||||
|
|
||||||
// ExpReplay calls
|
// ExperienceHandler calls
|
||||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
||||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||||
|
@ -147,16 +148,17 @@ public class QLearningDiscreteTest {
|
||||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||||
};
|
};
|
||||||
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
|
||||||
|
assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
|
||||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||||
Transition tr = expReplay.transitions.get(i);
|
StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
|
||||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
|
||||||
assertEquals(expectedTrActions[i], tr.getAction());
|
assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
|
||||||
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
|
|
||||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||||
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
|
||||||
|
|
||||||
// trainEpoch result
|
// trainEpoch result
|
||||||
assertEquals(initStepCount + 16, result.getStepCounter());
|
assertEquals(initStepCount + 16, result.getStepCounter());
|
||||||
|
@ -167,22 +169,18 @@ public class QLearningDiscreteTest {
|
||||||
|
|
||||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
|
||||||
int epsilonNbStep, Random rnd) {
|
int epsilonNbStep, Random rnd) {
|
||||||
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
setExpReplay(expReplay);
|
setExperienceHandler(experienceHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
protected DataSet setTarget(List<Transition<Integer>> transitions) {
|
||||||
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setExpReplay(IExpReplay<Integer> exp) {
|
|
||||||
this.expReplay = exp;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IDataManager.StatEntry trainEpoch() {
|
public IDataManager.StatEntry trainEpoch() {
|
||||||
return super.trainEpoch();
|
return super.trainEpoch();
|
||||||
|
|
|
@ -19,6 +19,7 @@ public class MockDQN implements IDQN {
|
||||||
public final List<INDArray> outputParams = new ArrayList<>();
|
public final List<INDArray> outputParams = new ArrayList<>();
|
||||||
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
|
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
|
||||||
public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
|
public final List<Pair<INDArray, INDArray>> gradientParams = new ArrayList<>();
|
||||||
|
public final List<INDArray> outputAllParams = new ArrayList<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NeuralNetwork[] getNeuralNetworks() {
|
public NeuralNetwork[] getNeuralNetworks() {
|
||||||
|
@ -58,7 +59,8 @@ public class MockDQN implements IDQN {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray[] outputAll(INDArray batch) {
|
public INDArray[] outputAll(INDArray batch) {
|
||||||
return new INDArray[0];
|
outputAllParams.add(batch);
|
||||||
|
return new INDArray[] { batch.mul(-1.0) };
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
package org.deeplearning4j.rl4j.support;
|
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class MockExpReplay implements IExpReplay<Integer> {
|
|
||||||
|
|
||||||
public List<Transition<Integer>> transitions = new ArrayList<>();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ArrayList<Transition<Integer>> getBatch() {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void store(Transition<Integer> transition) {
|
|
||||||
transitions.add(transition);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.experience.ExperienceHandler;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
|
||||||
|
public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
|
||||||
|
public Observation finalObservation;
|
||||||
|
public boolean isGenerateTrainingBatchCalled;
|
||||||
|
public boolean isResetCalled;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
|
||||||
|
addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setFinalObservation(Observation observation) {
|
||||||
|
finalObservation = observation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Transition<Integer>> generateTrainingBatch() {
|
||||||
|
isGenerateTrainingBatchCalled = true;
|
||||||
|
return new ArrayList<Transition<Integer>>() {
|
||||||
|
{
|
||||||
|
add(new Transition<Integer>(null, 0, 0.0, false));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
isResetCalled = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getTrainingBatchSize() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,6 +5,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public class MockObservationSpace implements ObservationSpace {
|
public class MockObservationSpace implements ObservationSpace {
|
||||||
|
|
||||||
|
private final int[] shape;
|
||||||
|
|
||||||
|
public MockObservationSpace() {
|
||||||
|
this(new int[] { 1 });
|
||||||
|
}
|
||||||
|
|
||||||
|
public MockObservationSpace(int[] shape) {
|
||||||
|
this.shape = shape;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getName() {
|
public String getName() {
|
||||||
return null;
|
return null;
|
||||||
|
@ -12,7 +22,7 @@ public class MockObservationSpace implements ObservationSpace {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int[] getShape() {
|
public int[] getShape() {
|
||||||
return new int[] { 1 };
|
return shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.experience.StateActionPair;
|
||||||
|
import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
|
||||||
|
|
||||||
|
public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
|
||||||
|
experiences.add(experience);
|
||||||
|
return new Gradient[0];
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue