RL4J: Extract TD Target calculations (StandardDQN and DoubleDQN) (#8267)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2019-10-08 20:14:47 -04:00 committed by Samuel Audet
parent 50b13fadc8
commit 3aa51e210a
16 changed files with 732 additions and 107 deletions

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync;
import lombok.AllArgsConstructor;
import lombok.Value;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -27,6 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
* State, Action, Reward, (isTerminal), State
*/
@Value
@AllArgsConstructor
public class Transition<A> {
INDArray[] observation;

View File

@ -43,7 +43,7 @@ import java.util.List;
*/
@Slf4j
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN> {
extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {
// FIXME Changed for refac
// @Getter
@ -61,28 +61,19 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
public abstract MDP<O, A, AS> getMdp();
protected abstract IDQN getCurrentDQN();
public abstract IDQN getQNetwork();
protected abstract IDQN getTargetDQN();
public abstract IDQN getTargetQNetwork();
protected abstract void setTargetDQN(IDQN dqn);
protected INDArray dqnOutput(INDArray input) {
return getCurrentDQN().output(input);
}
protected INDArray targetDqnOutput(INDArray input) {
return getTargetDQN().output(input);
}
protected abstract void setTargetQNetwork(IDQN dqn);
protected void updateTargetNetwork() {
log.info("Update target network");
setTargetDQN(getCurrentDQN().clone());
setTargetQNetwork(getQNetwork().clone());
}
public IDQN getNeuralNet() {
return getCurrentDQN();
return getQNetwork();
}
public abstract QLConfiguration getConfiguration();

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
/**
* An interface for all implementations capable of supplying a Q-Network
*
* @author Alexandre Boulanger
*/
public interface QNetworkSource {
IDQN getQNetwork();
}

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
/**
* An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
*
* @author Alexandre Boulanger
*/
public interface TargetQNetworkSource extends QNetworkSource {
IDQN getTargetQNetwork();
}

View File

@ -16,12 +16,14 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
@ -29,10 +31,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.ArrayList;
@ -53,29 +52,38 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final private IDQN currentDQN;
@Getter
private DQNPolicy<O> policy;
@Getter
private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
@Getter
@Setter
private IDQN targetDQN;
final private IDQN qNetwork;
@Getter
@Setter(AccessLevel.PROTECTED)
private IDQN targetQNetwork;
private int lastAction;
private INDArray[] history = null;
private double accuReward = 0;
ITDTargetAlgorithm tdTargetAlgorithm;
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
super(conf);
this.configuration = conf;
this.mdp = mdp;
currentDQN = dqn;
targetDQN = dqn.clone();
policy = new DQNPolicy(getCurrentDQN());
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
this);
mdp.getActionSpace().setSeed(conf.getSeed());
tdTargetAlgorithm = conf.isDoubleDQN()
? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
: new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
}
public void postEpoch() {
@ -134,7 +142,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (hstack.shape().length > 2)
hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
INDArray qs = getCurrentDQN().output(hstack);
INDArray qs = getQNetwork().output(hstack);
int maxAction = Learning.getMaxAction(qs);
maxQ = qs.getDouble(maxAction);
@ -160,96 +168,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
getExpReplay().store(trans);
if (getStepCounter() > updateStart) {
Pair<INDArray, INDArray> targets = setTarget(getExpReplay().getBatch());
getCurrentDQN().fit(targets.getFirst(), targets.getSecond());
DataSet targets = setTarget(getExpReplay().getBatch());
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
}
history = nhistory;
accuReward = 0;
}
return new QLStepReturn<O>(maxQ, getCurrentDQN().getLatestScore(), stepReply);
return new QLStepReturn<O>(maxQ, getQNetwork().getLatestScore(), stepReply);
}
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
if (transitions.size() == 0)
throw new IllegalArgumentException("too few transitions");
int size = transitions.size();
// TODO: Remove once we use DataSets in observations
int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
: getHistoryProcessor().getConf().getShape();
int[] nshape = makeShape(size, shape);
INDArray obs = Nd4j.create(nshape);
INDArray nextObs = Nd4j.create(nshape);
int[] actions = new int[size];
boolean[] areTerminal = new boolean[size];
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape));
for (int i = 0; i < size; i++) {
Transition<Integer> trans = transitions.get(i);
areTerminal[i] = trans.isTerminal();
actions[i] = trans.getAction();
INDArray[] obsArray = trans.getObservation();
if (obs.rank() == 2) {
obs.putRow(i, obsArray[0]);
} else {
for (int j = 0; j < obsArray.length; j++) {
obs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
}
}
INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
if (nextObs.rank() == 2) {
nextObs.putRow(i, nextObsArray[0]);
} else {
for (int j = 0; j < nextObsArray.length; j++) {
nextObs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
}
}
}
if (getHistoryProcessor() != null) {
obs.muli(1.0 / getHistoryProcessor().getScale());
nextObs.muli(1.0 / getHistoryProcessor().getScale());
// TODO: Remove once we use DataSets in observations
if(getHistoryProcessor() != null) {
((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
}
INDArray dqnOutputAr = dqnOutput(obs);
INDArray dqnOutputNext = dqnOutput(nextObs);
INDArray targetDqnOutputNext = targetDqnOutput(nextObs);
INDArray tempQ = null;
INDArray getMaxAction = null;
if (getConfiguration().isDoubleDQN()) {
getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
} else {
tempQ = Nd4j.max(targetDqnOutputNext, 1);
}
for (int i = 0; i < size; i++) {
double yTar = transitions.get(i).getReward();
if (!areTerminal[i]) {
double q = 0;
if (getConfiguration().isDoubleDQN()) {
q += targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i));
} else
q += tempQ.getDouble(i);
yTar += getConfiguration().getGamma() * q;
}
double previousV = dqnOutputAr.getDouble(i, actions[i]);
double lowB = previousV - getConfiguration().getErrorClamp();
double highB = previousV + getConfiguration().getErrorClamp();
double clamped = Math.min(highB, Math.max(yTar, lowB));
dqnOutputAr.putScalar(i, actions[i], clamped);
}
return new Pair(obs, dqnOutputAr);
return tdTargetAlgorithm.computeTDTargets(transitions);
}
}

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* The base of all DQN based algorithms
*
* @author Alexandre Boulanger
*
*/
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
private final TargetQNetworkSource qTargetNetworkSource;
/**
* In litterature, this corresponds to Q{net}(s(t+1), a)
*/
protected INDArray qNetworkNextObservation;
/**
* In litterature, this corresponds to Q{tnet}(s(t+1), a)
*/
protected INDArray targetQNetworkNextObservation;
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
this.qTargetNetworkSource = qTargetNetworkSource;
}
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
this.qTargetNetworkSource = qTargetNetworkSource;
}
@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations);
qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
}
}

View File

@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import lombok.Setter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.List;
/**
* The base of all TD target calculation algorithms that use deep learning.
*
* @author Alexandre Boulanger
*/
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
protected final QNetworkSource qNetworkSource;
protected final double gamma;
private final double errorClamp;
private final boolean isClamped;
@Setter
private int[] nShape; // TODO: Remove once we use DataSets in observations
@Setter
private double scale = 1.0; // TODO: Remove once we use DataSets in observations
/**
*
* @param qNetworkSource The source of the Q-Network
* @param gamma The discount factor
* @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
*/
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
this.qNetworkSource = qNetworkSource;
this.gamma = gamma;
this.errorClamp = errorClamp;
isClamped = !Double.isNaN(errorClamp);
}
/**
*
* @param qNetworkSource The source of the Q-Network
* @param gamma The discount factor
* Note: Error clamping is disabled with this ctor
*/
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
this(qNetworkSource, gamma, Double.NaN);
}
/**
* Called just before the calculation starts
* @param observations A INDArray of all observations stacked on dimension 0
* @param nextObservations A INDArray of all next observations stacked on dimension 0
*/
protected void initComputation(INDArray observations, INDArray nextObservations) {
// Do nothing
}
/**
* Compute the new estimated Q-Value for every transition in the batch
*
* @param batchIdx The index in the batch of the current transition
* @param reward The reward of the current transition
* @param isTerminal True if it's the last transition of the "game"
* @return The estimated Q-Value
*/
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
@Override
public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
int size = transitions.size();
INDArray observations = Nd4j.create(nShape);
INDArray nextObservations = Nd4j.create(nShape);
// TODO: Remove once we use DataSets in observations
for (int i = 0; i < size; i++) {
Transition<Integer> trans = transitions.get(i);
INDArray[] obsArray = trans.getObservation();
if (observations.rank() == 2) {
observations.putRow(i, obsArray[0]);
} else {
for (int j = 0; j < obsArray.length; j++) {
observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
}
}
INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
if (nextObservations.rank() == 2) {
nextObservations.putRow(i, nextObsArray[0]);
} else {
for (int j = 0; j < nextObsArray.length; j++) {
nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
}
}
}
// TODO: Remove once we use DataSets in observations
if(scale != 1.0) {
observations.muli(1.0 / scale);
nextObservations.muli(1.0 / scale);
}
initComputation(observations, nextObservations);
INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
for (int i = 0; i < size; ++i) {
Transition<Integer> transition = transitions.get(i);
double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
if(isClamped) {
double previousQValue = updatedQValues.getDouble(i, transition.getAction());
double lowBound = previousQValue - errorClamp;
double highBound = previousQValue + errorClamp;
yTarget = Math.min(highBound, Math.max(yTarget, lowBound));
}
updatedQValues.putScalar(i, transition.getAction(), yTarget);
}
return new org.nd4j.linalg.dataset.DataSet(observations, updatedQValues);
}
}

View File

@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* The Double-DQN algorithm based on "Deep Reinforcement Learning with Double Q-learning" (https://arxiv.org/abs/1509.06461)
*
* @author Alexandre Boulanger
*/
public class DoubleDQN extends BaseDQNAlgorithm {
private static final int ACTION_DIMENSION_IDX = 1;
// In litterature, this corresponds to: max_{a}Q(s_{t+1}, a)
private INDArray maxActionsFromQNetworkNextObservation;
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
}
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
}
@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations);
maxActionsFromQNetworkNextObservation = Nd4j.argMax(qNetworkNextObservation, ACTION_DIMENSION_IDX);
}
/**
* In litterature, this corresponds to:<br />
* Q(s_t, a_t) = R_{t+1} + \gamma * Q_{tar}(s_{t+1}, max_{a}Q(s_{t+1}, a))
* @param batchIdx The index in the batch of the current transition
* @param reward The reward of the current transition
* @param isTerminal True if it's the last transition of the "game"
* @return The estimated Q-Value
*/
@Override
protected double computeTarget(int batchIdx, double reward, boolean isTerminal) {
double yTarget = reward;
if (!isTerminal) {
yTarget += gamma * targetQNetworkNextObservation.getDouble(batchIdx, maxActionsFromQNetworkNextObservation.getInt(batchIdx));
}
return yTarget;
}
}

View File

@ -0,0 +1,38 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.nd4j.linalg.dataset.api.DataSet;
import java.util.List;
/**
* The interface of all TD target calculation algorithms.
*
* @param <A> The type of actions
*
* @author Alexandre Boulanger
*/
public interface ITDTargetAlgorithm<A> {
/**
* Compute the updated estimated Q-Values for every transition
* @param transitions The transitions from the experience replay
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions
*/
DataSet computeTDTargets(List<Transition<A>> transitions);
}

View File

@ -0,0 +1,67 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
* The Standard DQN algorithm based on "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/abs/1312.5602)
*
* @author Alexandre Boulanger
*/
public class StandardDQN extends BaseDQNAlgorithm {
private static final int ACTION_DIMENSION_IDX = 1;
// In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
private INDArray maxActionsFromQTargetNextObservation;
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
}
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
}
@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations);
maxActionsFromQTargetNextObservation = Nd4j.max(targetQNetworkNextObservation, ACTION_DIMENSION_IDX);
}
/**
* In litterature, this corresponds to:<br />
* Q(s_t, a_t) = R_{t+1} + \gamma * max_{a}Q_{tar}(s_{t+1}, a)
* @param batchIdx The index in the batch of the current transition
* @param reward The reward of the current transition
* @param isTerminal True if it's the last transition of the "game"
* @return The estimated Q-Value
*/
@Override
protected double computeTarget(int batchIdx, double reward, boolean isTerminal) {
double yTarget = reward;
if (!isTerminal) {
yTarget += gamma * maxActionsFromQTargetNextObservation.getDouble(batchIdx);
}
return yTarget;
}
}

View File

@ -35,7 +35,7 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
void fit(INDArray input, INDArray labels);
void fit(INDArray input, INDArray[] labels);
INDArray output(INDArray batch);
INDArray[] outputAll(INDArray batch);

View File

@ -12,8 +12,8 @@ import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
@ -139,8 +139,8 @@ public class QLearningDiscreteTest {
}
@Override
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
}
public void setExpReplay(IExpReplay<Integer> exp){

View File

@ -0,0 +1,105 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
public class DoubleDQNTest {
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
@Test
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
@Test
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 3, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
assertEquals(2.0 + 0.5 * -44.0, evaluatedQValues.getDouble(1, 1), 0.0001);
assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);
}
}

View File

@ -0,0 +1,104 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
public class StandardDQNTest {
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
@Test
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 1, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals(1.0 + 0.5 * 22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
}
@Test
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
sut.setNShape(new int[] { 3, 2 });
// Act
DataSet result = sut.computeTDTargets(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001);
assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001);
assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);
}
}

View File

@ -1,15 +1,28 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import lombok.Setter;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.io.OutputStream;
public class MockDQN implements IDQN {
private final double mult;
public MockDQN() {
this(1.0);
}
public MockDQN(double mult) {
this.mult = mult;
}
@Override
public NeuralNetwork[] getNeuralNetworks() {
return new NeuralNetwork[0];
@ -37,7 +50,11 @@ public class MockDQN implements IDQN {
@Override
public INDArray output(INDArray batch) {
return null;
if(mult != 1.0) {
return batch.dup().muli(mult);
}
return batch;
}
@Override

View File

@ -0,0 +1,26 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
public class MockTargetQNetworkSource implements TargetQNetworkSource {
private final IDQN qNetwork;
private final IDQN targetQNetwork;
public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
this.qNetwork = qNetwork;
this.targetQNetwork = targetQNetwork;
}
@Override
public IDQN getTargetQNetwork() {
return targetQNetwork;
}
@Override
public IDQN getQNetwork() {
return qNetwork;
}
}