RL4J: Extract TD Target calculations (StandardDQN and DoubleDQN) (#8267)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

Branch: master
Parent commit: 50b13fadc8
Commit: 3aa51e210a
Transition.java

@@ -16,6 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync;

+import lombok.AllArgsConstructor;
 import lombok.Value;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -27,6 +28,7 @@ import org.nd4j.linalg.factory.Nd4j;
  * State, Action, Reward, (isTerminal), State
  */
 @Value
+@AllArgsConstructor
 public class Transition<A> {

     INDArray[] observation;
QLearning.java

@@ -43,7 +43,7 @@ import java.util.List;
  */
 @Slf4j
 public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
-                extends SyncLearning<O, A, AS, IDQN> {
+                extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {

     // FIXME Changed for refac
     // @Getter
@@ -61,28 +61,19 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     public abstract MDP<O, A, AS> getMdp();

-    protected abstract IDQN getCurrentDQN();
+    public abstract IDQN getQNetwork();

-    protected abstract IDQN getTargetDQN();
+    public abstract IDQN getTargetQNetwork();

-    protected abstract void setTargetDQN(IDQN dqn);
+    protected abstract void setTargetQNetwork(IDQN dqn);

-    protected INDArray dqnOutput(INDArray input) {
-        return getCurrentDQN().output(input);
-    }
-
-    protected INDArray targetDqnOutput(INDArray input) {
-        return getTargetDQN().output(input);
-    }
-
     protected void updateTargetNetwork() {
         log.info("Update target network");
-        setTargetDQN(getCurrentDQN().clone());
+        setTargetQNetwork(getQNetwork().clone());
     }

     public IDQN getNeuralNet() {
-        return getCurrentDQN();
+        return getQNetwork();
     }

     public abstract QLConfiguration getConfiguration();
QNetworkSource.java (new file)

@@ -0,0 +1,28 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
 * An interface for all implementations capable of supplying a Q-Network
 *
 * @author Alexandre Boulanger
 */
public interface QNetworkSource {
    IDQN getQNetwork();
}
TargetQNetworkSource.java (new file)

@@ -0,0 +1,28 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning;

import org.deeplearning4j.rl4j.network.dqn.IDQN;

/**
 * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
 *
 * @author Alexandre Boulanger
 */
public interface TargetQNetworkSource extends QNetworkSource {
    IDQN getTargetQNetwork();
}
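Note: QLearning (above) now implements TargetQNetworkSource itself, which is how the extracted TD-target classes reach the two networks. Below is a minimal sketch of what any other provider of these interfaces would have to supply, assuming only the interfaces above plus IDQN.clone(); the class name SimpleTargetQNetworkSource and the refreshTarget() helper are illustrative, not part of this commit (the commit's own MockTargetQNetworkSource test double, further down, is a similar fixed-pair provider):

    import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
    import org.deeplearning4j.rl4j.network.dqn.IDQN;

    // Illustrative provider: holds a Q-network and its periodically refreshed target copy.
    public class SimpleTargetQNetworkSource implements TargetQNetworkSource {

        private final IDQN qNetwork;  // the network being trained
        private IDQN targetQNetwork;  // the frozen copy used when computing TD targets

        public SimpleTargetQNetworkSource(IDQN qNetwork) {
            this.qNetwork = qNetwork;
            this.targetQNetwork = qNetwork.clone(); // same pattern as QLearningDiscrete's constructor
        }

        @Override
        public IDQN getQNetwork() {
            return qNetwork;
        }

        @Override
        public IDQN getTargetQNetwork() {
            return targetQNetwork;
        }

        // Mirrors QLearning.updateTargetNetwork(): refresh the frozen copy from the live network.
        public void refreshTarget() {
            targetQNetwork = qNetwork.clone();
        }
    }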
QLearningDiscrete.java

@@ -16,12 +16,14 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

+import lombok.AccessLevel;
 import lombok.Getter;
 import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
+import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
@@ -29,10 +31,7 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.INDArrayIndex;
-import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.primitives.Pair;
+import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.util.ArrayUtil;

 import java.util.ArrayList;
@@ -53,29 +52,38 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
     @Getter
     final private MDP<O, Integer, DiscreteSpace> mdp;
     @Getter
-    final private IDQN currentDQN;
-    @Getter
     private DQNPolicy<O> policy;
     @Getter
     private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;

     @Getter
-    @Setter
-    private IDQN targetDQN;
+    final private IDQN qNetwork;
+    @Getter
+    @Setter(AccessLevel.PROTECTED)
+    private IDQN targetQNetwork;

     private int lastAction;
     private INDArray[] history = null;
     private double accuReward = 0;

+    ITDTargetAlgorithm tdTargetAlgorithm;
+
     public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
                     int epsilonNbStep) {
         super(conf);
         this.configuration = conf;
         this.mdp = mdp;
-        currentDQN = dqn;
-        targetDQN = dqn.clone();
-        policy = new DQNPolicy(getCurrentDQN());
+        qNetwork = dqn;
+        targetQNetwork = dqn.clone();
+        policy = new DQNPolicy(getQNetwork());
         egPolicy = new EpsGreedy(policy, mdp, conf.getUpdateStart(), epsilonNbStep, getRandom(), conf.getMinEpsilon(),
                         this);
         mdp.getActionSpace().setSeed(conf.getSeed());
+
+        tdTargetAlgorithm = conf.isDoubleDQN()
+                ? new DoubleDQN(this, conf.getGamma(), conf.getErrorClamp())
+                : new StandardDQN(this, conf.getGamma(), conf.getErrorClamp());
     }

     public void postEpoch() {
@@ -134,7 +142,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
         if (hstack.shape().length > 2)
             hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));

-        INDArray qs = getCurrentDQN().output(hstack);
+        INDArray qs = getQNetwork().output(hstack);
         int maxAction = Learning.getMaxAction(qs);

         maxQ = qs.getDouble(maxAction);
@@ -160,96 +168,31 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
             getExpReplay().store(trans);

             if (getStepCounter() > updateStart) {
-                Pair<INDArray, INDArray> targets = setTarget(getExpReplay().getBatch());
-                getCurrentDQN().fit(targets.getFirst(), targets.getSecond());
+                DataSet targets = setTarget(getExpReplay().getBatch());
+                getQNetwork().fit(targets.getFeatures(), targets.getLabels());
             }

             history = nhistory;
             accuReward = 0;
         }

-        return new QLStepReturn<O>(maxQ, getCurrentDQN().getLatestScore(), stepReply);
+        return new QLStepReturn<O>(maxQ, getQNetwork().getLatestScore(), stepReply);
     }

-    protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
+    protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
         if (transitions.size() == 0)
             throw new IllegalArgumentException("too few transitions");

-        int size = transitions.size();
-
+        // TODO: Remove once we use DataSets in observations
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape()
                         : getHistoryProcessor().getConf().getShape();
-        int[] nshape = makeShape(size, shape);
-        INDArray obs = Nd4j.create(nshape);
-        INDArray nextObs = Nd4j.create(nshape);
-        int[] actions = new int[size];
-        boolean[] areTerminal = new boolean[size];
-
-        for (int i = 0; i < size; i++) {
-            Transition<Integer> trans = transitions.get(i);
-            areTerminal[i] = trans.isTerminal();
-            actions[i] = trans.getAction();
-
-            INDArray[] obsArray = trans.getObservation();
-            if (obs.rank() == 2) {
-                obs.putRow(i, obsArray[0]);
-            } else {
-                for (int j = 0; j < obsArray.length; j++) {
-                    obs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
-                }
-            }
-
-            INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
-            if (nextObs.rank() == 2) {
-                nextObs.putRow(i, nextObsArray[0]);
-            } else {
-                for (int j = 0; j < nextObsArray.length; j++) {
-                    nextObs.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
-                }
-            }
-        }
+        ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setNShape(makeShape(transitions.size(), shape));

+        // TODO: Remove once we use DataSets in observations
         if(getHistoryProcessor() != null) {
-            obs.muli(1.0 / getHistoryProcessor().getScale());
-            nextObs.muli(1.0 / getHistoryProcessor().getScale());
+            ((BaseTDTargetAlgorithm) tdTargetAlgorithm).setScale(getHistoryProcessor().getScale());
         }

-        INDArray dqnOutputAr = dqnOutput(obs);
-
-        INDArray dqnOutputNext = dqnOutput(nextObs);
-        INDArray targetDqnOutputNext = targetDqnOutput(nextObs);
-
-        INDArray tempQ = null;
-        INDArray getMaxAction = null;
-        if (getConfiguration().isDoubleDQN()) {
-            getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
-        } else {
-            tempQ = Nd4j.max(targetDqnOutputNext, 1);
-        }
-
-        for (int i = 0; i < size; i++) {
-            double yTar = transitions.get(i).getReward();
-            if (!areTerminal[i]) {
-                double q = 0;
-                if (getConfiguration().isDoubleDQN()) {
-                    q += targetDqnOutputNext.getDouble(i, getMaxAction.getInt(i));
-                } else
-                    q += tempQ.getDouble(i);
-
-                yTar += getConfiguration().getGamma() * q;
-            }
-
-            double previousV = dqnOutputAr.getDouble(i, actions[i]);
-            double lowB = previousV - getConfiguration().getErrorClamp();
-            double highB = previousV + getConfiguration().getErrorClamp();
-            double clamped = Math.min(highB, Math.max(yTar, lowB));
-
-            dqnOutputAr.putScalar(i, actions[i], clamped);
-        }
-
-        return new Pair(obs, dqnOutputAr);
+        return tdTargetAlgorithm.computeTDTargets(transitions);
     }
 }
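Taken together, the refactored path that setTarget() now delegates to looks roughly like this from the outside: gather replayed transitions, let the configured ITDTargetAlgorithm turn them into a DataSet, then fit the Q-network on its features and labels. A minimal sketch under assumed values, using this commit's MockDQN and MockTargetQNetworkSource test doubles in place of real networks; the class name TDTargetFlowSketch, the { 1, 2 } shape and the gamma/errorClamp numbers are illustrative only:

    import org.deeplearning4j.rl4j.learning.sync.Transition;
    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
    import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
    import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.api.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.ArrayList;
    import java.util.List;

    public class TDTargetFlowSketch {
        public static void main(String[] args) {
            // Stand-ins for the learner's Q-network and its target copy.
            MockDQN qNetwork = new MockDQN();
            MockDQN targetQNetwork = new MockDQN();
            MockTargetQNetworkSource networks = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

            // One replayed transition: observation, action, reward, isTerminal, next observation.
            List<Transition<Integer>> batch = new ArrayList<>();
            batch.add(new Transition<Integer>(
                    new INDArray[] { Nd4j.create(new double[] { 1.1, 2.2 }) },
                    0, 1.0, false,
                    Nd4j.create(new double[] { 11.0, 22.0 })));

            // StandardDQN or DoubleDQN, mirroring the conf.isDoubleDQN() selection above.
            StandardDQN tdTargetAlgorithm = new StandardDQN(networks, 0.99, 1.0); // gamma, errorClamp
            tdTargetAlgorithm.setNShape(new int[] { 1, 2 }); // batch size x observation shape

            // The DataSet's features are the observations, its labels the updated Q-values.
            DataSet targets = tdTargetAlgorithm.computeTDTargets(batch);
            qNetwork.fit(targets.getFeatures(), targets.getLabels());
        }
    }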
BaseDQNAlgorithm.java (new file)

@@ -0,0 +1,62 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * The base of all DQN-based algorithms
 *
 * @author Alexandre Boulanger
 */
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

    private final TargetQNetworkSource qTargetNetworkSource;

    /**
     * In the literature, this corresponds to Q{net}(s(t+1), a)
     */
    protected INDArray qNetworkNextObservation;

    /**
     * In the literature, this corresponds to Q{tnet}(s(t+1), a)
     */
    protected INDArray targetQNetworkNextObservation;

    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
        super(qTargetNetworkSource, gamma);
        this.qTargetNetworkSource = qTargetNetworkSource;
    }

    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
        this.qTargetNetworkSource = qTargetNetworkSource;
    }

    @Override
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);

        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);

        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
        targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
    }
}
BaseTDTargetAlgorithm.java (new file)

@@ -0,0 +1,147 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import lombok.Setter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.List;

/**
 * The base of all TD target calculation algorithms that use deep learning.
 *
 * @author Alexandre Boulanger
 */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

    protected final QNetworkSource qNetworkSource;
    protected final double gamma;

    private final double errorClamp;
    private final boolean isClamped;

    @Setter
    private int[] nShape; // TODO: Remove once we use DataSets in observations
    @Setter
    private double scale = 1.0; // TODO: Remove once we use DataSets in observations

    /**
     *
     * @param qNetworkSource The source of the Q-Network
     * @param gamma The discount factor
     * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
        this.qNetworkSource = qNetworkSource;
        this.gamma = gamma;

        this.errorClamp = errorClamp;
        isClamped = !Double.isNaN(errorClamp);
    }

    /**
     *
     * @param qNetworkSource The source of the Q-Network
     * @param gamma The discount factor
     * Note: Error clamping is disabled with this ctor
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
        this(qNetworkSource, gamma, Double.NaN);
    }

    /**
     * Called just before the calculation starts
     * @param observations An INDArray of all observations stacked on dimension 0
     * @param nextObservations An INDArray of all next observations stacked on dimension 0
     */
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        // Do nothing
    }

    /**
     * Compute the new estimated Q-Value for a single transition in the batch
     *
     * @param batchIdx The index in the batch of the current transition
     * @param reward The reward of the current transition
     * @param isTerminal True if it's the last transition of the "game"
     * @return The estimated Q-Value
     */
    protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);

    @Override
    public DataSet computeTDTargets(List<Transition<Integer>> transitions) {

        int size = transitions.size();

        INDArray observations = Nd4j.create(nShape);
        INDArray nextObservations = Nd4j.create(nShape);

        // TODO: Remove once we use DataSets in observations
        for (int i = 0; i < size; i++) {
            Transition<Integer> trans = transitions.get(i);

            INDArray[] obsArray = trans.getObservation();
            if (observations.rank() == 2) {
                observations.putRow(i, obsArray[0]);
            } else {
                for (int j = 0; j < obsArray.length; j++) {
                    observations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, obsArray[j]);
                }
            }

            INDArray[] nextObsArray = Transition.append(trans.getObservation(), trans.getNextObservation());
            if (nextObservations.rank() == 2) {
                nextObservations.putRow(i, nextObsArray[0]);
            } else {
                for (int j = 0; j < nextObsArray.length; j++) {
                    nextObservations.put(new INDArrayIndex[] {NDArrayIndex.point(i), NDArrayIndex.point(j)}, nextObsArray[j]);
                }
            }
        }

        // TODO: Remove once we use DataSets in observations
        if(scale != 1.0) {
            observations.muli(1.0 / scale);
            nextObservations.muli(1.0 / scale);
        }

        initComputation(observations, nextObservations);

        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);

        for (int i = 0; i < size; ++i) {
            Transition<Integer> transition = transitions.get(i);
            double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());

            if(isClamped) {
                double previousQValue = updatedQValues.getDouble(i, transition.getAction());
                double lowBound = previousQValue - errorClamp;
                double highBound = previousQValue + errorClamp;
                yTarget = Math.min(highBound, Math.max(yTarget, lowBound));
            }
            updatedQValues.putScalar(i, transition.getAction(), yTarget);
        }

        return new org.nd4j.linalg.dataset.DataSet(observations, updatedQValues);
    }
}
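The errorClamp argument above bounds how far a freshly computed target may move away from the Q-network's current estimate for the chosen action (Double.NaN, or the two-argument constructor, disables it). A small worked instance of that clamping arithmetic, with assumed numbers:

    // Assumed, illustrative values only.
    double previousQValue = 2.0;  // current Q-network estimate for (s_t, a_t)
    double errorClamp = 1.0;
    double yTarget = 5.0;         // raw TD target: reward + gamma * next-state value

    // Same clamping as in computeTDTargets(): keep the target within +/- errorClamp of the estimate.
    double clamped = Math.min(previousQValue + errorClamp, Math.max(yTarget, previousQValue - errorClamp));
    // clamped == 3.0: a single fit can move the Q-value by at most errorClamp.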
DoubleDQN.java (new file)

@@ -0,0 +1,67 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * The Double-DQN algorithm based on "Deep Reinforcement Learning with Double Q-learning" (https://arxiv.org/abs/1509.06461)
 *
 * @author Alexandre Boulanger
 */
public class DoubleDQN extends BaseDQNAlgorithm {

    private static final int ACTION_DIMENSION_IDX = 1;

    // In the literature, this corresponds to: argmax_{a}Q(s_{t+1}, a)
    private INDArray maxActionsFromQNetworkNextObservation;

    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
        super(qTargetNetworkSource, gamma);
    }

    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
    }

    @Override
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);

        maxActionsFromQNetworkNextObservation = Nd4j.argMax(qNetworkNextObservation, ACTION_DIMENSION_IDX);
    }

    /**
     * In the literature, this corresponds to:<br />
     *      Q(s_t, a_t) = R_{t+1} + \gamma * Q_{tar}(s_{t+1}, argmax_{a}Q(s_{t+1}, a))
     * @param batchIdx The index in the batch of the current transition
     * @param reward The reward of the current transition
     * @param isTerminal True if it's the last transition of the "game"
     * @return The estimated Q-Value
     */
    @Override
    protected double computeTarget(int batchIdx, double reward, boolean isTerminal) {
        double yTarget = reward;
        if (!isTerminal) {
            yTarget += gamma * targetQNetworkNextObservation.getDouble(batchIdx, maxActionsFromQNetworkNextObservation.getInt(batchIdx));
        }

        return yTarget;
    }
}
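As a worked instance of the Double-DQN target above, these are the numbers behind the non-terminal DoubleDQNTest case added later in this commit (gamma = 0.5, and the mock target network returns -1 times its input); plain doubles are used here purely for illustration:

    double reward = 1.0;
    double gamma = 0.5;
    double[] qNetNext = { 11.0, 22.0 };   // Q_net(s_{t+1}, a) from the online Q-network
    double[] qTarNext = { -11.0, -22.0 }; // Q_tar(s_{t+1}, a) from the target network

    // Double DQN: the online network selects the action, the target network evaluates it.
    int selectedAction = qNetNext[1] > qNetNext[0] ? 1 : 0;     // argmax_a Q_net(s_{t+1}, a) -> 1
    double yTarget = reward + gamma * qTarNext[selectedAction]; // 1.0 + 0.5 * (-22.0) = -10.0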
ITDTargetAlgorithm.java (new file)

@@ -0,0 +1,38 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.nd4j.linalg.dataset.api.DataSet;

import java.util.List;

/**
 * The interface of all TD target calculation algorithms.
 *
 * @param <A> The type of actions
 *
 * @author Alexandre Boulanger
 */
public interface ITDTargetAlgorithm<A> {
    /**
     * Compute the updated estimated Q-Values for every transition
     * @param transitions The transitions from the experience replay
     * @return A DataSet where every element is the observation and the estimated Q-Values for all actions
     */
    DataSet computeTDTargets(List<Transition<A>> transitions);
}
StandardDQN.java (new file)

@@ -0,0 +1,67 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * The Standard DQN algorithm based on "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/abs/1312.5602)
 *
 * @author Alexandre Boulanger
 */
public class StandardDQN extends BaseDQNAlgorithm {

    private static final int ACTION_DIMENSION_IDX = 1;

    // In the literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
    private INDArray maxActionsFromQTargetNextObservation;

    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
        super(qTargetNetworkSource, gamma);
    }

    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
    }

    @Override
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);

        maxActionsFromQTargetNextObservation = Nd4j.max(targetQNetworkNextObservation, ACTION_DIMENSION_IDX);
    }

    /**
     * In the literature, this corresponds to:<br />
     *      Q(s_t, a_t) = R_{t+1} + \gamma * max_{a}Q_{tar}(s_{t+1}, a)
     * @param batchIdx The index in the batch of the current transition
     * @param reward The reward of the current transition
     * @param isTerminal True if it's the last transition of the "game"
     * @return The estimated Q-Value
     */
    @Override
    protected double computeTarget(int batchIdx, double reward, boolean isTerminal) {
        double yTarget = reward;
        if (!isTerminal) {
            yTarget += gamma * maxActionsFromQTargetNextObservation.getDouble(batchIdx);
        }

        return yTarget;
    }
}
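The matching worked instance for the standard target, following the non-terminal StandardDQNTest case added later in this commit (gamma = 0.5, the mock target network echoes its input unchanged); again, plain doubles for illustration only:

    double reward = 1.0;
    double gamma = 0.5;
    double[] qTarNext = { 11.0, 22.0 }; // Q_tar(s_{t+1}, a) from the target network

    // Standard DQN: the target network both selects and evaluates the best next action.
    double maxNext = Math.max(qTarNext[0], qTarNext[1]); // max_a Q_tar(s_{t+1}, a) = 22.0
    double yTarget = reward + gamma * maxNext;           // 1.0 + 0.5 * 22.0 = 12.0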
QLearningDiscreteTest.java

@@ -12,8 +12,8 @@ import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.primitives.Pair;

 import java.util.ArrayList;
 import java.util.List;
@@ -139,8 +139,8 @@ public class QLearningDiscreteTest {
         }

         @Override
-        protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
-            return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
+        protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
+            return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
         }

         public void setExpReplay(IExpReplay<Integer> exp){
DoubleDQNTest.java (new file)

@@ -0,0 +1,105 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class DoubleDQNTest {

    @Test
    public void when_isTerminal_expect_rewardValueAtIdx0() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN();
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
            }
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 1, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
    }

    @Test
    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN(-1.0);
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
            }
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 1, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
    }

    @Test
    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN(-1.0);
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
            }
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 3, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals(1.0 + 0.5 * -22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);

        assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
        assertEquals(2.0 + 0.5 * -44.0, evaluatedQValues.getDouble(1, 1), 0.0001);

        assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
        assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);
    }

}
StandardDQNTest.java (new file)

@@ -0,0 +1,104 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.*;

public class StandardDQNTest {

    @Test
    public void when_isTerminal_expect_rewardValueAtIdx0() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN();
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, true, Nd4j.create(new double[]{11.0, 22.0})));
            }
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 1, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals(1.0, evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
    }

    @Test
    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN();
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
            }
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 1, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals(1.0 + 0.5 * 22.0, evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);
    }

    @Test
    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

        // Assemble
        MockDQN qNetwork = new MockDQN();
        MockDQN targetQNetwork = new MockDQN();
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{1.1, 2.2})}, 0, 1.0, false, Nd4j.create(new double[]{11.0, 22.0})));
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{3.3, 4.4})}, 1, 2.0, false, Nd4j.create(new double[]{33.0, 44.0})));
                add(new Transition<Integer>(new INDArray[]{Nd4j.create(new double[]{5.5, 6.6})}, 0, 3.0, true, Nd4j.create(new double[]{55.0, 66.0})));
            }
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
        sut.setNShape(new int[] { 3, 2 });

        // Act
        DataSet result = sut.computeTDTargets(transitions);

        // Assert
        INDArray evaluatedQValues = result.getLabels();
        assertEquals((1.0 + 0.5 * 22.0), evaluatedQValues.getDouble(0, 0), 0.0001);
        assertEquals(2.2, evaluatedQValues.getDouble(0, 1), 0.0001);

        assertEquals(3.3, evaluatedQValues.getDouble(1, 0), 0.0001);
        assertEquals((2.0 + 0.5 * 44.0), evaluatedQValues.getDouble(1, 1), 0.0001);

        assertEquals(3.0, evaluatedQValues.getDouble(2, 0), 0.0001); // terminal: reward only
        assertEquals(6.6, evaluatedQValues.getDouble(2, 1), 0.0001);
    }

}
MockDQN.java

@@ -1,15 +1,28 @@
 package org.deeplearning4j.rl4j.learning.sync.support;

+import lombok.Setter;
 import org.deeplearning4j.nn.api.NeuralNetwork;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;

 import java.io.IOException;
 import java.io.OutputStream;

 public class MockDQN implements IDQN {
+
+    private final double mult;
+
+    public MockDQN() {
+        this(1.0);
+    }
+
+    public MockDQN(double mult) {
+        this.mult = mult;
+    }
+
     @Override
     public NeuralNetwork[] getNeuralNetworks() {
         return new NeuralNetwork[0];
@@ -37,7 +50,11 @@ public class MockDQN implements IDQN {

     @Override
     public INDArray output(INDArray batch) {
-        return null;
+        if(mult != 1.0) {
+            return batch.dup().muli(mult);
+        }
+
+        return batch;
     }

     @Override
MockTargetQNetworkSource.java (new file)

@@ -0,0 +1,26 @@
package org.deeplearning4j.rl4j.learning.sync.support;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;

public class MockTargetQNetworkSource implements TargetQNetworkSource {

    private final IDQN qNetwork;
    private final IDQN targetQNetwork;

    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
        this.qNetwork = qNetwork;
        this.targetQNetwork = targetQNetwork;
    }

    @Override
    public IDQN getTargetQNetwork() {
        return targetQNetwork;
    }

    @Override
    public IDQN getQNetwork() {
        return qNetwork;
    }
}