Merge remote-tracking branch 'eclipse/master'
commit
7ea07de76b
4
pom.xml
4
pom.xml
|
@ -303,8 +303,8 @@
|
|||
<leptonica.version>1.79.0</leptonica.version>
|
||||
<hdf5.version>1.10.6</hdf5.version>
|
||||
<ale.version>0.6.1</ale.version>
|
||||
<gym.version>0.15.4</gym.version>
|
||||
<tensorflow.version>1.15.0</tensorflow.version>
|
||||
<gym.version>0.15.5</gym.version>
|
||||
<tensorflow.version>1.15.2</tensorflow.version>
|
||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||
|
||||
<commons-compress.version>1.18</commons-compress.version>
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
package org.deeplearning4j.rl4j.learning;
|
||||
|
||||
public interface EpochStepCounter {
|
||||
int getCurrentEpochStep();
|
||||
}
|
|
@ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP;
|
|||
/**
|
||||
* The common API between Learning and AsyncThread.
|
||||
*
|
||||
* Express the ability to count the number of step of the current training.
|
||||
* Factorisation of a feature between threads in async and learning process
|
||||
* for the web monitoring
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
*/
|
||||
public interface IEpochTrainer {
|
||||
public interface IEpochTrainer extends EpochStepCounter {
|
||||
int getStepCounter();
|
||||
int getEpochCounter();
|
||||
IHistoryProcessor getHistoryProcessor();
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable;
|
|||
*
|
||||
* A common interface that any training method should implement
|
||||
*/
|
||||
public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable {
|
||||
public interface ILearning<O, A, AS extends ActionSpace<A>> {
|
||||
|
||||
IPolicy<O, A> getPolicy();
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import lombok.Value;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
|
@ -53,55 +52,6 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
|
|||
return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
|
||||
}
|
||||
|
||||
public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
|
||||
INDArray arr = Nd4j.create(((Encodable)obs).toArray());
|
||||
int[] shape = mdp.getObservationSpace().getShape();
|
||||
if (shape.length == 1)
|
||||
return arr.reshape(new long[] {1, arr.length()});
|
||||
else
|
||||
return arr.reshape(shape);
|
||||
}
|
||||
|
||||
public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp,
|
||||
IHistoryProcessor hp) {
|
||||
|
||||
O obs = mdp.reset();
|
||||
|
||||
int step = 0;
|
||||
double reward = 0;
|
||||
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||
|
||||
INDArray input = Learning.getInput(mdp, obs);
|
||||
if (isHistoryProcessor)
|
||||
hp.record(input);
|
||||
|
||||
|
||||
while (step < requiredFrame && !mdp.isDone()) {
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
if (step % skipFrame == 0 && isHistoryProcessor)
|
||||
hp.add(input);
|
||||
|
||||
StepReply<O> stepReply = mdp.step(action);
|
||||
reward += stepReply.getReward();
|
||||
obs = stepReply.getObservation();
|
||||
|
||||
input = Learning.getInput(mdp, obs);
|
||||
if (isHistoryProcessor)
|
||||
hp.record(input);
|
||||
|
||||
step++;
|
||||
|
||||
}
|
||||
|
||||
return new InitMdp(step, obs, reward);
|
||||
|
||||
}
|
||||
|
||||
public static int[] makeShape(int size, int[] shape) {
|
||||
int[] nshape = new int[shape.length + 1];
|
||||
nshape[0] = size;
|
||||
|
@ -122,16 +72,16 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
|
|||
|
||||
public abstract NN getNeuralNet();
|
||||
|
||||
public int incrementStep() {
|
||||
return stepCounter++;
|
||||
public void incrementStep() {
|
||||
stepCounter++;
|
||||
}
|
||||
|
||||
public int incrementEpoch() {
|
||||
return epochCounter++;
|
||||
public void incrementEpoch() {
|
||||
epochCounter++;
|
||||
}
|
||||
|
||||
public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
|
||||
historyProcessor = new HistoryProcessor(conf);
|
||||
setHistoryProcessor(new HistoryProcessor(conf));
|
||||
}
|
||||
|
||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.learning;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||
*
|
||||
* Express the ability to count the number of step of the current training.
|
||||
* Factorisation of a feature between threads in async and learning process
|
||||
* for the web monitoring
|
||||
*/
|
||||
public interface StepCountable {
|
||||
|
||||
int getStepCounter();
|
||||
|
||||
}
|
|
@ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
|
|||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
*/
|
||||
@Slf4j
|
||||
public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||
extends Thread implements StepCountable, IEpochTrainer {
|
||||
extends Thread implements IEpochTrainer {
|
||||
|
||||
@Getter
|
||||
private int threadNumber;
|
||||
|
@ -61,6 +59,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
@Getter @Setter
|
||||
private IHistoryProcessor historyProcessor;
|
||||
|
||||
@Getter
|
||||
private int currentEpochStep = 0;
|
||||
|
||||
private boolean isEpochStarted = false;
|
||||
private final LegacyMDPWrapper<O, A, AS> mdp;
|
||||
|
||||
|
@ -138,7 +139,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
|
||||
handleTraining(context);
|
||||
|
||||
if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
|
||||
boolean canContinue = finishEpoch(context);
|
||||
if (!canContinue) {
|
||||
break;
|
||||
|
@ -154,11 +155,10 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
}
|
||||
|
||||
private void handleTraining(RunContext context) {
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps);
|
||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
|
||||
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||
|
||||
context.obs = subEpochReturn.getLastObs();
|
||||
context.epochElapsedSteps += subEpochReturn.getSteps();
|
||||
context.rewards += subEpochReturn.getReward();
|
||||
context.score = subEpochReturn.getScore();
|
||||
}
|
||||
|
@ -169,7 +169,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
|
||||
context.obs = initMdp.getLastObs();
|
||||
context.rewards = initMdp.getReward();
|
||||
context.epochElapsedSteps = initMdp.getSteps();
|
||||
|
||||
isEpochStarted = true;
|
||||
preEpoch();
|
||||
|
@ -180,9 +179,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
private boolean finishEpoch(RunContext context) {
|
||||
isEpochStarted = false;
|
||||
postEpoch();
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score);
|
||||
IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
|
||||
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards);
|
||||
log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
|
||||
|
||||
return listeners.notifyEpochTrainingResult(this, statEntry);
|
||||
}
|
||||
|
@ -205,37 +204,30 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);
|
||||
|
||||
private Learning.InitMdp<Observation> refacInitMdp() {
|
||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||
IHistoryProcessor hp = getHistoryProcessor();
|
||||
currentEpochStep = 0;
|
||||
|
||||
Observation observation = mdp.reset();
|
||||
|
||||
int step = 0;
|
||||
double reward = 0;
|
||||
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||
|
||||
while (step < requiredFrame && !mdp.isDone()) {
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||
Observation observation = mdp.reset();
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
while (observation.isSkipped() && !mdp.isDone()) {
|
||||
StepReply<Observation> stepReply = mdp.step(action);
|
||||
|
||||
reward += stepReply.getReward();
|
||||
observation = stepReply.getObservation();
|
||||
|
||||
step++;
|
||||
|
||||
incrementStep();
|
||||
}
|
||||
|
||||
return new Learning.InitMdp(step, observation, reward);
|
||||
return new Learning.InitMdp(0, observation, reward);
|
||||
|
||||
}
|
||||
|
||||
public void incrementStep() {
|
||||
++stepCounter;
|
||||
++currentEpochStep;
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
|
@ -260,7 +252,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
|||
private static class RunContext {
|
||||
private Observation obs;
|
||||
private double rewards;
|
||||
private int epochElapsedSteps;
|
||||
private double score;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,19 +20,14 @@ import lombok.Getter;
|
|||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.Stack;
|
||||
|
||||
|
@ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
IPolicy<O, Integer> policy = getPolicy(current);
|
||||
|
||||
Integer action;
|
||||
Integer lastAction = null;
|
||||
Integer lastAction = getMdp().getActionSpace().noOp();
|
||||
IHistoryProcessor hp = getHistoryProcessor();
|
||||
int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
|
||||
|
||||
double reward = 0;
|
||||
double accuReward = 0;
|
||||
int i = 0;
|
||||
while (!getMdp().isDone() && i < nstep * skipFrame) {
|
||||
int stepAtStart = getCurrentEpochStep();
|
||||
int lastStep = nstep * skipFrame + stepAtStart;
|
||||
while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
|
||||
|
||||
//if step of training, just repeat lastAction
|
||||
if (i % skipFrame != 0 && lastAction != null) {
|
||||
if (obs.isSkipped()) {
|
||||
action = lastAction;
|
||||
} else {
|
||||
action = policy.nextAction(obs);
|
||||
|
@ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
|
||||
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||
|
||||
INDArray[] output = current.outputAll(obs.getData());
|
||||
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
||||
|
@ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
|
||||
reward += stepReply.getReward();
|
||||
|
||||
i++;
|
||||
incrementStep();
|
||||
lastAction = action;
|
||||
}
|
||||
|
@ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
//a bit of a trick usable because of how the stack is treated to init R
|
||||
// FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored.
|
||||
|
||||
if (getMdp().isDone() && i < nstep * skipFrame)
|
||||
if (getMdp().isDone() && getCurrentEpochStep() < lastStep)
|
||||
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
||||
else {
|
||||
INDArray[] output = null;
|
||||
|
@ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
rewards.add(new MiniTrans(obs.getData(), null, output, maxQ));
|
||||
}
|
||||
|
||||
getAsyncGlobal().enqueue(calcGradient(current, rewards), i);
|
||||
getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep());
|
||||
|
||||
return new SubEpochReturn(i, obs, reward, current.getLatestScore());
|
||||
return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
|
||||
}
|
||||
|
||||
public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards);
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
|
@ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
|||
import org.deeplearning4j.rl4j.learning.listener.*;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
|
||||
/**
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Value;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -34,7 +35,7 @@ import java.util.List;
|
|||
* @author Alexandre Boulanger
|
||||
*
|
||||
*/
|
||||
@Value
|
||||
@Data
|
||||
public class Transition<A> {
|
||||
|
||||
Observation observation;
|
||||
|
@ -43,12 +44,15 @@ public class Transition<A> {
|
|||
boolean isTerminal;
|
||||
INDArray nextObservation;
|
||||
|
||||
public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||
public Transition(Observation observation, A action, double reward, boolean isTerminal) {
|
||||
this.observation = observation;
|
||||
this.action = action;
|
||||
this.reward = reward;
|
||||
this.isTerminal = isTerminal;
|
||||
this.nextObservation = null;
|
||||
}
|
||||
|
||||
public void setNextObservation(Observation nextObservation) {
|
||||
// To conserve memory, only the most recent frame of the next observation is kept (if history is used).
|
||||
// The full nextObservation will be re-build from observation when needed.
|
||||
long[] nextObservationShape = nextObservation.getData().shape().clone();
|
||||
|
|
|
@ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
|
|||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
||||
|
@ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
|
|||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -49,7 +47,8 @@ import java.util.List;
|
|||
*/
|
||||
@Slf4j
|
||||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||
extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource {
|
||||
extends SyncLearning<O, A, AS, IDQN>
|
||||
implements TargetQNetworkSource, EpochStepCounter {
|
||||
|
||||
// FIXME Changed for refac
|
||||
// @Getter
|
||||
|
@ -104,18 +103,22 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
|
||||
protected abstract QLStepReturn<Observation> trainStep(Observation obs);
|
||||
|
||||
@Getter
|
||||
private int currentEpochStep = 0;
|
||||
|
||||
protected StatEntry trainEpoch() {
|
||||
resetNetworks();
|
||||
|
||||
InitMdp<Observation> initMdp = refacInitMdp();
|
||||
Observation obs = initMdp.getLastObs();
|
||||
|
||||
double reward = initMdp.getReward();
|
||||
int step = initMdp.getSteps();
|
||||
|
||||
Double startQ = Double.NaN;
|
||||
double meanQ = 0;
|
||||
int numQ = 0;
|
||||
List<Double> scores = new ArrayList<>();
|
||||
while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
||||
while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
|
||||
|
||||
if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
|
||||
updateTargetNetwork();
|
||||
|
@ -136,49 +139,53 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
|||
reward += stepR.getStepReply().getReward();
|
||||
obs = stepR.getStepReply().getObservation();
|
||||
incrementStep();
|
||||
step++;
|
||||
}
|
||||
|
||||
finishEpoch(obs);
|
||||
|
||||
meanQ /= (numQ + 0.001); //avoid div zero
|
||||
|
||||
|
||||
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, step, scores,
|
||||
StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
|
||||
getEgPolicy().getEpsilon(), startQ, meanQ);
|
||||
|
||||
return statEntry;
|
||||
}
|
||||
|
||||
protected void finishEpoch(Observation observation) {
|
||||
// Do Nothing
|
||||
}
|
||||
|
||||
@Override
|
||||
public void incrementStep() {
|
||||
super.incrementStep();
|
||||
++currentEpochStep;
|
||||
}
|
||||
|
||||
protected void resetNetworks() {
|
||||
getQNetwork().reset();
|
||||
getTargetQNetwork().reset();
|
||||
}
|
||||
|
||||
private InitMdp<Observation> refacInitMdp() {
|
||||
getQNetwork().reset();
|
||||
getTargetQNetwork().reset();
|
||||
currentEpochStep = 0;
|
||||
|
||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||
IHistoryProcessor hp = getHistoryProcessor();
|
||||
|
||||
Observation observation = mdp.reset();
|
||||
|
||||
int step = 0;
|
||||
double reward = 0;
|
||||
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||
|
||||
while (step < requiredFrame && !mdp.isDone()) {
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
|
||||
Observation observation = mdp.reset();
|
||||
|
||||
A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
while (observation.isSkipped() && !mdp.isDone()) {
|
||||
StepReply<Observation> stepReply = mdp.step(action);
|
||||
|
||||
reward += stepReply.getReward();
|
||||
observation = stepReply.getObservation();
|
||||
|
||||
step++;
|
||||
|
||||
incrementStep();
|
||||
}
|
||||
|
||||
return new InitMdp(step, observation, reward);
|
||||
return new InitMdp(0, observation, reward);
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -20,10 +20,13 @@ import lombok.AccessLevel;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
@ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
|
@ -68,6 +70,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
private int lastAction;
|
||||
private double accuReward = 0;
|
||||
|
||||
private Transition pendingTransition;
|
||||
|
||||
ITDTargetAlgorithm tdTargetAlgorithm;
|
||||
|
||||
protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
|
||||
|
@ -83,7 +87,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
int epsilonNbStep, Random random) {
|
||||
super(conf);
|
||||
this.configuration = conf;
|
||||
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this);
|
||||
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this);
|
||||
qNetwork = dqn;
|
||||
targetQNetwork = dqn.clone();
|
||||
policy = new DQNPolicy(getQNetwork());
|
||||
|
@ -108,8 +112,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
}
|
||||
|
||||
public void preEpoch() {
|
||||
lastAction = 0;
|
||||
lastAction = mdp.getActionSpace().noOp();
|
||||
accuReward = 0;
|
||||
pendingTransition = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||
super.setHistoryProcessor(historyProcessor);
|
||||
mdp.setHistoryProcessor(historyProcessor);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -120,9 +131,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
protected QLStepReturn<Observation> trainStep(Observation obs) {
|
||||
|
||||
Integer action;
|
||||
|
||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||
|
||||
|
||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||
int updateStart = getConfiguration().getUpdateStart()
|
||||
|
@ -131,7 +141,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||
|
||||
//if step of training, just repeat lastAction
|
||||
if (getStepCounter() % skipFrame != 0) {
|
||||
if (obs.isSkipped()) {
|
||||
action = lastAction;
|
||||
} else {
|
||||
INDArray qs = getQNetwork().output(obs);
|
||||
|
@ -145,22 +155,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
StepReply<Observation> stepReply = mdp.step(action);
|
||||
|
||||
Observation nextObservation = stepReply.getObservation();
|
||||
|
||||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
|
||||
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||
|
||||
Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation);
|
||||
getExpReplay().store(trans);
|
||||
// Add experience
|
||||
if(pendingTransition != null) {
|
||||
pendingTransition.setNextObservation(obs);
|
||||
getExpReplay().store(pendingTransition);
|
||||
}
|
||||
pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone());
|
||||
accuReward = 0;
|
||||
|
||||
// Update NN
|
||||
// FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
|
||||
if (getStepCounter() > updateStart) {
|
||||
DataSet targets = setTarget(getExpReplay().getBatch());
|
||||
getQNetwork().fit(targets.getFeatures(), targets.getLabels());
|
||||
}
|
||||
|
||||
accuReward = 0;
|
||||
}
|
||||
|
||||
return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
|
||||
|
@ -172,4 +185,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
|
||||
return tdTargetAlgorithm.computeTDTargets(transitions);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finishEpoch(Observation observation) {
|
||||
if(pendingTransition != null) {
|
||||
pendingTransition.setNextObservation(observation);
|
||||
getExpReplay().store(pendingTransition);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -92,6 +92,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
|||
theta = 0.1 * rnd.nextDouble() - 0.05;
|
||||
thetaDot = 0.1 * rnd.nextDouble() - 0.05;
|
||||
stepsBeyondDone = null;
|
||||
done = false;
|
||||
|
||||
return new State(new double[] { x, xDot, theta, thetaDot });
|
||||
}
|
||||
|
@ -126,7 +127,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre
|
|||
break;
|
||||
}
|
||||
|
||||
boolean done = x < -xThreshold || x > xThreshold
|
||||
done |= x < -xThreshold || x > xThreshold
|
||||
|| theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
|
||||
|
||||
double reward;
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.observation;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -28,6 +30,9 @@ public class Observation {
|
|||
|
||||
private final DataSet data;
|
||||
|
||||
@Getter @Setter
|
||||
private boolean skipped;
|
||||
|
||||
public Observation(INDArray[] data) {
|
||||
this(data, false);
|
||||
}
|
||||
|
|
|
@ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy;
|
|||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
|
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
|||
final private int epsilonNbStep;
|
||||
final private Random rnd;
|
||||
final private float minEpsilon;
|
||||
final private StepCountable learning;
|
||||
final private IEpochTrainer learning;
|
||||
|
||||
public NeuralNet getNeuralNet() {
|
||||
return policy.getNeuralNet();
|
||||
|
|
|
@ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.util.ArrayUtil;
|
||||
|
||||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||
|
@ -57,24 +52,22 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
|
|||
|
||||
@Override
|
||||
public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
|
||||
RefacStepCountable stepCountable = new RefacStepCountable();
|
||||
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable);
|
||||
resetNetworks();
|
||||
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
|
||||
LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
|
||||
|
||||
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
|
||||
Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
|
||||
Observation obs = initMdp.getLastObs();
|
||||
|
||||
double reward = initMdp.getReward();
|
||||
|
||||
A lastAction = mdpWrapper.getActionSpace().noOp();
|
||||
A action;
|
||||
stepCountable.setStepCounter(initMdp.getSteps());
|
||||
|
||||
while (!mdpWrapper.isDone()) {
|
||||
|
||||
if (stepCountable.getStepCounter() % skipFrame != 0) {
|
||||
if (obs.isSkipped()) {
|
||||
action = lastAction;
|
||||
} else {
|
||||
action = nextAction(obs);
|
||||
|
@ -86,52 +79,46 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
|
|||
reward += stepReply.getReward();
|
||||
|
||||
obs = stepReply.getObservation();
|
||||
stepCountable.increment();
|
||||
epochStepCounter.incrementEpochStep();
|
||||
}
|
||||
|
||||
return reward;
|
||||
}
|
||||
|
||||
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
|
||||
protected void resetNetworks() {
|
||||
getNeuralNet().reset();
|
||||
Observation observation = mdpWrapper.reset();
|
||||
}
|
||||
|
||||
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
|
||||
epochStepCounter.setCurrentEpochStep(0);
|
||||
|
||||
int step = 0;
|
||||
double reward = 0;
|
||||
|
||||
boolean isHistoryProcessor = hp != null;
|
||||
Observation observation = mdpWrapper.reset();
|
||||
|
||||
int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
|
||||
int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0;
|
||||
|
||||
while (step < requiredFrame && !mdpWrapper.isDone()) {
|
||||
|
||||
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP
|
||||
while (observation.isSkipped() && !mdpWrapper.isDone()) {
|
||||
|
||||
StepReply<Observation> stepReply = mdpWrapper.step(action);
|
||||
|
||||
reward += stepReply.getReward();
|
||||
observation = stepReply.getObservation();
|
||||
|
||||
step++;
|
||||
|
||||
epochStepCounter.incrementEpochStep();
|
||||
}
|
||||
|
||||
return new Learning.InitMdp(step, observation, reward);
|
||||
return new Learning.InitMdp(0, observation, reward);
|
||||
}
|
||||
|
||||
private class RefacStepCountable implements StepCountable {
|
||||
public class RefacEpochStepCounter implements EpochStepCounter {
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private int stepCounter = 0;
|
||||
private int currentEpochStep = 0;
|
||||
|
||||
public void increment() {
|
||||
++stepCounter;
|
||||
public void incrementEpochStep() {
|
||||
++currentEpochStep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getStepCounter() {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
package org.deeplearning4j.rl4j.util;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||
import org.deeplearning4j.rl4j.learning.StepCountable;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
|
@ -19,50 +20,20 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
private final MDP<O, A, AS> wrappedMDP;
|
||||
@Getter
|
||||
private final WrapperObservationSpace observationSpace;
|
||||
private final ILearning learning;
|
||||
|
||||
@Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
|
||||
private IHistoryProcessor historyProcessor;
|
||||
private final StepCountable stepCountable;
|
||||
private int skipFrame;
|
||||
|
||||
private int step = 0;
|
||||
private final EpochStepCounter epochStepCounter;
|
||||
|
||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) {
|
||||
this(wrappedMDP, learning, null, null);
|
||||
}
|
||||
private int skipFrame = 1;
|
||||
private int requiredFrame = 0;
|
||||
|
||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
|
||||
this(wrappedMDP, null, historyProcessor, stepCountable);
|
||||
}
|
||||
|
||||
private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) {
|
||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
|
||||
this.wrappedMDP = wrappedMDP;
|
||||
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
|
||||
this.learning = learning;
|
||||
this.historyProcessor = historyProcessor;
|
||||
this.stepCountable = stepCountable;
|
||||
}
|
||||
|
||||
private IHistoryProcessor getHistoryProcessor() {
|
||||
if(historyProcessor != null) {
|
||||
return historyProcessor;
|
||||
}
|
||||
|
||||
if (learning != null) {
|
||||
return learning.getHistoryProcessor();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||
this.historyProcessor = historyProcessor;
|
||||
}
|
||||
|
||||
private int getStep() {
|
||||
if(stepCountable != null) {
|
||||
return stepCountable.getStepCounter();
|
||||
}
|
||||
|
||||
return learning.getStepCounter();
|
||||
this.epochStepCounter = epochStepCounter;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -83,9 +54,12 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
|
||||
if(historyProcessor != null) {
|
||||
skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
||||
|
||||
historyProcessor.add(rawObservation);
|
||||
}
|
||||
step = 0;
|
||||
|
||||
observation.setSkipped(skipFrame != 0);
|
||||
|
||||
return observation;
|
||||
}
|
||||
|
@ -97,21 +71,18 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
StepReply<O> rawStepReply = wrappedMDP.step(a);
|
||||
INDArray rawObservation = getInput(rawStepReply.getObservation());
|
||||
|
||||
++step;
|
||||
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
|
||||
|
||||
int requiredFrame = 0;
|
||||
if(historyProcessor != null) {
|
||||
historyProcessor.record(rawObservation);
|
||||
|
||||
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
||||
if ((getStep() % skipFrame == 0 && step >= requiredFrame)
|
||||
|| (step % skipFrame == 0 && step < requiredFrame )){
|
||||
if (stepOfObservation % skipFrame == 0) {
|
||||
historyProcessor.add(rawObservation);
|
||||
}
|
||||
}
|
||||
|
||||
Observation observation;
|
||||
if(historyProcessor != null && step >= requiredFrame) {
|
||||
if(historyProcessor != null && stepOfObservation >= requiredFrame) {
|
||||
observation = new Observation(historyProcessor.getHistory(), true);
|
||||
observation.getData().muli(1.0 / historyProcessor.getScale());
|
||||
}
|
||||
|
@ -119,6 +90,10 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
observation = new Observation(new INDArray[] { rawObservation }, false);
|
||||
}
|
||||
|
||||
if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
|
||||
observation.setSkipped(true);
|
||||
}
|
||||
|
||||
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
||||
}
|
||||
|
||||
|
@ -134,7 +109,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
|
||||
@Override
|
||||
public MDP<Observation, A, AS> newInstance() {
|
||||
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning);
|
||||
return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
|
||||
}
|
||||
|
||||
private INDArray getInput(O obs) {
|
||||
|
|
|
@ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest {
|
|||
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||
TrainingListenerList listeners = new TrainingListenerList();
|
||||
MockPolicy policyMock = new MockPolicy();
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0);
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
|
||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||
|
||||
// Act
|
||||
|
@ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest {
|
|||
// Assert
|
||||
assertEquals(2, sut.trainSubEpochResults.size());
|
||||
double[][] expectedLastObservations = new double[][] {
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||
};
|
||||
double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
|
||||
for(int i = 0; i < 2; ++i) {
|
||||
|
@ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest {
|
|||
assertEquals(2, asyncGlobalMock.enqueueCallCount);
|
||||
|
||||
// HistoryProcessor
|
||||
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
|
||||
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
|
||||
for(int i = 0; i < expectedAddValues.length; ++i) {
|
||||
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
|
||||
|
@ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest {
|
|||
// Policy
|
||||
double[][] expectedPolicyInputs = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
};
|
||||
assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
|
||||
for(int i = 0; i < expectedPolicyInputs.length; ++i) {
|
||||
|
@ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest {
|
|||
assertEquals(2, nnMock.copyCallCount);
|
||||
double[][] expectedNNInputs = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans
|
||||
};
|
||||
assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size());
|
||||
for(int i = 0; i < expectedNNInputs.length; ++i) {
|
||||
|
@ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest {
|
|||
double[][][] expectedMinitransObs = new double[][][] {
|
||||
new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation
|
||||
},
|
||||
new double[][] {
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation
|
||||
}
|
||||
};
|
||||
double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 };
|
||||
|
|
|
@ -5,15 +5,12 @@ import lombok.Getter;
|
|||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -91,7 +88,7 @@ public class AsyncThreadTest {
|
|||
|
||||
// Assert
|
||||
assertEquals(numberOfEpochs, context.listener.statEntries.size());
|
||||
int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 };
|
||||
int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
|
||||
double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
|
||||
+ 1.0; // Reward from trainSubEpoch()
|
||||
for(int i = 0; i < numberOfEpochs; ++i) {
|
||||
|
@ -114,7 +111,7 @@ public class AsyncThreadTest {
|
|||
// Assert
|
||||
assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
|
||||
double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
|
||||
for(int i = 0; i < context.sut.getEpochCounter(); ++i) {
|
||||
for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
|
||||
MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
|
||||
assertEquals(2, params.nstep);
|
||||
assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
|
||||
|
@ -199,7 +196,9 @@ public class AsyncThreadTest {
|
|||
protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
|
||||
asyncGlobal.increaseCurrentLoop();
|
||||
trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
|
||||
setStepCounter(getStepCounter() + nstep);
|
||||
for(int i = 0; i < nstep; ++i) {
|
||||
incrementStep();
|
||||
}
|
||||
return new SubEpochReturn(nstep, null, 1.0, 1.0);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
123, 234, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
123, 234, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition);
|
||||
List<Transition<Integer>> results = sut.getBatch(1);
|
||||
|
||||
|
@ -36,12 +36,12 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
|
@ -78,12 +78,12 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
|
@ -100,12 +100,12 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
|
@ -131,16 +131,16 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
7, 8, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
9, 10, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
7, 8, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
9, 10, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
|
@ -168,16 +168,16 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
7, 8, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
9, 10, false, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
7, 8, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
9, 10, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
sut.store(transition3);
|
||||
|
@ -196,6 +196,12 @@ public class ExpReplayTest {
|
|||
|
||||
assertEquals(5, (int)results.get(2).getAction());
|
||||
assertEquals(6, (int)results.get(2).getReward());
|
||||
}
|
||||
|
||||
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) {
|
||||
Transition<Integer> result = new Transition<Integer>(observation, action, reward, false);
|
||||
result.setNextObservation(nextObservation);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
|
@ -88,12 +89,15 @@ public class SyncLearningTest {
|
|||
|
||||
private final LConfiguration conf;
|
||||
|
||||
@Getter
|
||||
private int currentEpochStep = 0;
|
||||
|
||||
public MockSyncLearning(LConfiguration conf) {
|
||||
this.conf = conf;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void preEpoch() { }
|
||||
protected void preEpoch() { currentEpochStep = 0; }
|
||||
|
||||
@Override
|
||||
protected void postEpoch() { }
|
||||
|
@ -101,7 +105,7 @@ public class SyncLearningTest {
|
|||
@Override
|
||||
protected IDataManager.StatEntry trainEpoch() {
|
||||
setStepCounter(getStepCounter() + 1);
|
||||
return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0);
|
||||
return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,7 +21,7 @@ public class TransitionTest {
|
|||
Observation nextObservation = buildObservation(nextObs);
|
||||
|
||||
// Act
|
||||
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
||||
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
|
||||
|
||||
// Assert
|
||||
double[][] expectedObservation = new double[][] { obs };
|
||||
|
@ -52,7 +52,7 @@ public class TransitionTest {
|
|||
Observation nextObservation = buildObservation(nextObs);
|
||||
|
||||
// Act
|
||||
Transition transition = new Transition(observation, 123, 234.0, false, nextObservation);
|
||||
Transition transition = buildTransition(observation, 123, 234.0, nextObservation);
|
||||
|
||||
// Assert
|
||||
assertExpected(obs, transition.getObservation().getData());
|
||||
|
@ -71,12 +71,12 @@ public class TransitionTest {
|
|||
double[] obs1 = new double[] { 0.0, 1.0, 2.0 };
|
||||
Observation observation1 = buildObservation(obs1);
|
||||
Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 });
|
||||
transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1));
|
||||
transitions.add(buildTransition(observation1,0, 0.0, nextObservation1));
|
||||
|
||||
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||
Observation observation2 = buildObservation(obs2);
|
||||
Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 });
|
||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||
|
||||
// Act
|
||||
INDArray result = Transition.buildStackedObservations(transitions);
|
||||
|
@ -101,7 +101,7 @@ public class TransitionTest {
|
|||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||
|
||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||
|
||||
double[][] obs2 = new double[][] {
|
||||
{ 10.0, 11.0, 12.0 },
|
||||
|
@ -112,7 +112,7 @@ public class TransitionTest {
|
|||
|
||||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||
|
||||
// Act
|
||||
INDArray result = Transition.buildStackedObservations(transitions);
|
||||
|
@ -131,13 +131,13 @@ public class TransitionTest {
|
|||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||
Observation observation1 = buildObservation(obs1);
|
||||
Observation nextObservation1 = buildObservation(nextObs1);
|
||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||
|
||||
double[] obs2 = new double[] { 10.0, 11.0, 12.0 };
|
||||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||
Observation observation2 = buildObservation(obs2);
|
||||
Observation nextObservation2 = buildObservation(nextObs2);
|
||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||
|
||||
// Act
|
||||
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||
|
@ -162,7 +162,7 @@ public class TransitionTest {
|
|||
double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 };
|
||||
Observation nextObservation1 = buildNextObservation(obs1, nextObs1);
|
||||
|
||||
transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1));
|
||||
transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1));
|
||||
|
||||
double[][] obs2 = new double[][] {
|
||||
{ 10.0, 11.0, 12.0 },
|
||||
|
@ -174,7 +174,7 @@ public class TransitionTest {
|
|||
double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 };
|
||||
Observation nextObservation2 = buildNextObservation(obs2, nextObs2);
|
||||
|
||||
transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2));
|
||||
transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2));
|
||||
|
||||
// Act
|
||||
INDArray result = Transition.buildStackedNextObservations(transitions);
|
||||
|
@ -207,7 +207,13 @@ public class TransitionTest {
|
|||
Nd4j.create(obs[1]).reshape(1, 3),
|
||||
};
|
||||
return new Observation(nextHistory);
|
||||
}
|
||||
|
||||
private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
|
||||
Transition result = new Transition(observation, action, reward, false);
|
||||
result.setNextObservation(nextObservation);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private void assertExpected(double[] expected, INDArray actual) {
|
||||
|
|
|
@ -40,8 +40,10 @@ public class QLearningDiscreteTest {
|
|||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||
MockMDP mdp = new MockMDP(observationSpace, random);
|
||||
|
||||
int initStepCount = 8;
|
||||
|
||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
initStepCount, 1.0, 0, 0, 0, 0, true);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||
|
@ -60,7 +62,7 @@ public class QLearningDiscreteTest {
|
|||
for(int i = 0; i < expectedRecords.length; ++i) {
|
||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
||||
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
|
||||
assertEquals(expectedAdds.length, hp.addCalls.size());
|
||||
for(int i = 0; i < expectedAdds.length; ++i) {
|
||||
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
|
||||
|
@ -75,19 +77,19 @@ public class QLearningDiscreteTest {
|
|||
assertEquals(14, dqn.outputParams.size());
|
||||
double[][] expectedDQNOutput = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||
};
|
||||
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||
INDArray outputParam = dqn.outputParams.get(i);
|
||||
|
@ -105,19 +107,20 @@ public class QLearningDiscreteTest {
|
|||
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||
|
||||
// ExpReplay calls
|
||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 };
|
||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||
double[][] expectedTrObservations = new double[][] {
|
||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 },
|
||||
new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
||||
};
|
||||
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||
Transition tr = expReplay.transitions.get(i);
|
||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||
|
@ -129,7 +132,7 @@ public class QLearningDiscreteTest {
|
|||
}
|
||||
|
||||
// trainEpoch result
|
||||
assertEquals(16, result.getStepCounter());
|
||||
assertEquals(initStepCount + 16, result.getStepCounter());
|
||||
assertEquals(300.0, result.getReward(), 0.00001);
|
||||
assertTrue(dqn.hasBeenReset);
|
||||
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
|
||||
|
|
|
@ -26,7 +26,7 @@ public class DoubleDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||
}
|
||||
};
|
||||
|
@ -52,7 +52,7 @@ public class DoubleDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||
}
|
||||
};
|
||||
|
@ -78,11 +78,11 @@ public class DoubleDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(builtTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
||||
add(builtTransition(buildObservation(new double[]{3.3, 4.4}),
|
||||
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
||||
add(builtTransition(buildObservation(new double[]{5.5, 6.6}),
|
||||
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||
}
|
||||
};
|
||||
|
@ -108,4 +108,11 @@ public class DoubleDQNTest {
|
|||
private Observation buildObservation(double[] data) {
|
||||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||
}
|
||||
|
||||
private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
|
||||
result.setNextObservation(nextObservation);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ public class StandardDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, true, buildObservation(new double[]{11.0, 22.0})));
|
||||
}
|
||||
};
|
||||
|
@ -51,7 +51,7 @@ public class StandardDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||
}
|
||||
};
|
||||
|
@ -77,11 +77,11 @@ public class StandardDQNTest {
|
|||
|
||||
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
||||
{
|
||||
add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}),
|
||||
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
||||
0, 1.0, false, buildObservation(new double[]{11.0, 22.0})));
|
||||
add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}),
|
||||
add(buildTransition(buildObservation(new double[]{3.3, 4.4}),
|
||||
1, 2.0, false, buildObservation(new double[]{33.0, 44.0})));
|
||||
add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}),
|
||||
add(buildTransition(buildObservation(new double[]{5.5, 6.6}),
|
||||
0, 3.0, true, buildObservation(new double[]{55.0, 66.0})));
|
||||
}
|
||||
};
|
||||
|
@ -108,4 +108,10 @@ public class StandardDQNTest {
|
|||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||
}
|
||||
|
||||
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||
Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal);
|
||||
result.setNextObservation(nextObservation);
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -198,7 +198,7 @@ public class PolicyTest {
|
|||
assertEquals(465.0, totalReward, 0.0001);
|
||||
|
||||
// HistoryProcessor
|
||||
assertEquals(27, hp.addCalls.size());
|
||||
assertEquals(16, hp.addCalls.size());
|
||||
assertEquals(31, hp.recordCalls.size());
|
||||
for(int i=0; i <= 30; ++i) {
|
||||
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
|
|
|
@ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest {
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getCurrentEpochStep() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private IHistoryProcessor historyProcessor;
|
||||
|
|
|
@ -93,6 +93,9 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
|||
private boolean done = false;
|
||||
|
||||
public GymEnv(String envId, boolean render, boolean monitor) {
|
||||
this(envId, render, monitor, (Integer)null);
|
||||
}
|
||||
public GymEnv(String envId, boolean render, boolean monitor, Integer seed) {
|
||||
this.envId = envId;
|
||||
this.render = render;
|
||||
this.monitor = monitor;
|
||||
|
@ -107,6 +110,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
|||
Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null));
|
||||
checkPythonError();
|
||||
}
|
||||
if (seed != null) {
|
||||
Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null));
|
||||
checkPythonError();
|
||||
}
|
||||
PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null);
|
||||
int[] shape = new int[(int)PyTuple_Size(shapeTuple)];
|
||||
for (int i = 0; i < shape.length; i++) {
|
||||
|
@ -125,7 +132,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
|||
}
|
||||
|
||||
public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
|
||||
this(envId, render, monitor);
|
||||
this(envId, render, monitor, null, actions);
|
||||
}
|
||||
public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) {
|
||||
this(envId, render, monitor, seed);
|
||||
actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue