Merge remote-tracking branch 'eclipse/master'
commit 7ea07de76b

pom.xml (4)
							| @ -303,8 +303,8 @@ | ||||
|         <leptonica.version>1.79.0</leptonica.version> | ||||
|         <hdf5.version>1.10.6</hdf5.version> | ||||
|         <ale.version>0.6.1</ale.version> | ||||
|         <gym.version>0.15.4</gym.version> | ||||
|         <tensorflow.version>1.15.0</tensorflow.version> | ||||
|         <gym.version>0.15.5</gym.version> | ||||
|         <tensorflow.version>1.15.2</tensorflow.version> | ||||
|         <tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version> | ||||
| 
 | ||||
|         <commons-compress.version>1.18</commons-compress.version> | ||||
|  | ||||
| @ -0,0 +1,5 @@ | ||||
| package org.deeplearning4j.rl4j.learning; | ||||
| 
 | ||||
| public interface EpochStepCounter { | ||||
|     int getCurrentEpochStep(); | ||||
| } | ||||
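
The new interface is a single getter; a minimal sketch of an implementation follows (the class name, incrementEpochStep() and reset() are invented for illustration, but they mirror the RefacEpochStepCounter that appears later in this commit):

    import org.deeplearning4j.rl4j.learning.EpochStepCounter;

    // Hypothetical minimal implementation, for illustration only.
    public class SimpleEpochStepCounter implements EpochStepCounter {
        private int currentEpochStep = 0;

        @Override
        public int getCurrentEpochStep() {
            return currentEpochStep;
        }

        // Called once per environment step within the current epoch.
        public void incrementEpochStep() {
            ++currentEpochStep;
        }

        // Called at the start of each new epoch.
        public void reset() {
            currentEpochStep = 0;
        }
    }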
| @ -21,9 +21,14 @@ import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| /** | ||||
|  * The common API between Learning and AsyncThread. | ||||
|  * | ||||
|  * Expresses the ability to count the number of steps of the current training. | ||||
|  * This factors out a feature shared between the async threads and the | ||||
|  * learning process, for use by the web monitoring. | ||||
|  * | ||||
|  * @author Alexandre Boulanger | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||
|  */ | ||||
| public interface IEpochTrainer { | ||||
| public interface IEpochTrainer extends EpochStepCounter { | ||||
|     int getStepCounter(); | ||||
|     int getEpochCounter(); | ||||
|     IHistoryProcessor getHistoryProcessor(); | ||||
|  | ||||
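
Since IEpochTrainer now extends EpochStepCounter, a single reference exposes the lifetime, epoch-count and per-epoch views together; a hypothetical caller's sketch (the helper class is invented):

    import org.deeplearning4j.rl4j.learning.IEpochTrainer;

    // Hypothetical reporting helper, for illustration only.
    class TrainerReport {
        static void print(IEpochTrainer trainer) {
            System.out.printf("total steps: %d, epochs done: %d, steps into current epoch: %d%n",
                    trainer.getStepCounter(),        // lifetime step count
                    trainer.getEpochCounter(),       // completed epochs
                    trainer.getCurrentEpochStep());  // inherited from EpochStepCounter
        }
    }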
| @ -26,7 +26,7 @@ import org.deeplearning4j.rl4j.space.Encodable; | ||||
|  * | ||||
|  * A common interface that any training method should implement | ||||
|  */ | ||||
| public interface ILearning<O, A, AS extends ActionSpace<A>> extends StepCountable { | ||||
| public interface ILearning<O, A, AS extends ActionSpace<A>> { | ||||
| 
 | ||||
|     IPolicy<O, A> getPolicy(); | ||||
| 
 | ||||
|  | ||||
| @ -21,7 +21,6 @@ import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import lombok.Value; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| @ -53,55 +52,6 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura | ||||
|         return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0); | ||||
|     } | ||||
| 
 | ||||
|     public static <O, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) { | ||||
|         INDArray arr = Nd4j.create(((Encodable)obs).toArray()); | ||||
|         int[] shape = mdp.getObservationSpace().getShape(); | ||||
|         if (shape.length == 1) | ||||
|             return arr.reshape(new long[] {1, arr.length()}); | ||||
|         else | ||||
|             return arr.reshape(shape); | ||||
|     } | ||||
| 
 | ||||
|     public static <O, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp, | ||||
|                     IHistoryProcessor hp) { | ||||
| 
 | ||||
|         O obs = mdp.reset(); | ||||
| 
 | ||||
|         int step = 0; | ||||
|         double reward = 0; | ||||
| 
 | ||||
|         boolean isHistoryProcessor = hp != null; | ||||
| 
 | ||||
|         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; | ||||
|         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; | ||||
| 
 | ||||
|         INDArray input = Learning.getInput(mdp, obs); | ||||
|         if (isHistoryProcessor) | ||||
|             hp.record(input); | ||||
| 
 | ||||
| 
 | ||||
|         while (step < requiredFrame && !mdp.isDone()) { | ||||
| 
 | ||||
|             A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|             if (step % skipFrame == 0 && isHistoryProcessor) | ||||
|                 hp.add(input); | ||||
| 
 | ||||
|             StepReply<O> stepReply = mdp.step(action); | ||||
|             reward += stepReply.getReward(); | ||||
|             obs = stepReply.getObservation(); | ||||
| 
 | ||||
|             input = Learning.getInput(mdp, obs); | ||||
|             if (isHistoryProcessor) | ||||
|                 hp.record(input); | ||||
| 
 | ||||
|             step++; | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         return new InitMdp(step, obs, reward); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public static int[] makeShape(int size, int[] shape) { | ||||
|         int[] nshape = new int[shape.length + 1]; | ||||
|         nshape[0] = size; | ||||
| @ -122,16 +72,16 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura | ||||
| 
 | ||||
|     public abstract NN getNeuralNet(); | ||||
| 
 | ||||
|     public int incrementStep() { | ||||
|         return stepCounter++; | ||||
|     public void incrementStep() { | ||||
|         stepCounter++; | ||||
|     } | ||||
| 
 | ||||
|     public int incrementEpoch() { | ||||
|         return epochCounter++; | ||||
|     public void incrementEpoch() { | ||||
|         epochCounter++; | ||||
|     } | ||||
| 
 | ||||
|     public void setHistoryProcessor(HistoryProcessor.Configuration conf) { | ||||
|         historyProcessor = new HistoryProcessor(conf); | ||||
|         setHistoryProcessor(new HistoryProcessor(conf)); | ||||
|     } | ||||
| 
 | ||||
|     public void setHistoryProcessor(IHistoryProcessor historyProcessor) { | ||||
|  | ||||
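
Note that the Configuration overload now routes through the IHistoryProcessor overload instead of assigning the field directly. A sketch of why that matters (class names invented): a subclass can intercept both entry points with one override, which is exactly what QLearningDiscrete does later in this commit to forward the processor to its MDP wrapper.

    // Hypothetical subclass, for illustration only.
    public class MyLearning extends SomeLearning {
        @Override
        public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
            super.setHistoryProcessor(historyProcessor);
            // Both setHistoryProcessor(Configuration) and
            // setHistoryProcessor(IHistoryProcessor) now funnel through here,
            // e.g. to forward the processor: mdp.setHistoryProcessor(historyProcessor);
        }
    }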
| @ -1,30 +0,0 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.deeplearning4j.rl4j.learning; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. | ||||
|  * | ||||
|  * Express the ability to count the number of step of the current training. | ||||
|  * Factorisation of a feature between threads in async and learning process | ||||
|  * for the web monitoring | ||||
|  */ | ||||
| public interface StepCountable { | ||||
| 
 | ||||
|     int getStepCounter(); | ||||
| 
 | ||||
| } | ||||
| @ -30,8 +30,6 @@ import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.deeplearning4j.rl4j.policy.IPolicy; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.IDataManager; | ||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -48,7 +46,7 @@ import org.nd4j.linalg.factory.Nd4j; | ||||
|  */ | ||||
| @Slf4j | ||||
| public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet> | ||||
|                 extends Thread implements StepCountable, IEpochTrainer { | ||||
|                 extends Thread implements IEpochTrainer { | ||||
| 
 | ||||
|     @Getter | ||||
|     private int threadNumber; | ||||
| @ -61,6 +59,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
|     @Getter @Setter | ||||
|     private IHistoryProcessor historyProcessor; | ||||
| 
 | ||||
|     @Getter | ||||
|     private int currentEpochStep = 0; | ||||
| 
 | ||||
|     private boolean isEpochStarted = false; | ||||
|     private final LegacyMDPWrapper<O, A, AS> mdp; | ||||
| 
 | ||||
| @ -138,7 +139,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
| 
 | ||||
|                 handleTraining(context); | ||||
| 
 | ||||
|                 if (context.epochElapsedSteps >= getConf().getMaxEpochStep() || getMdp().isDone()) { | ||||
|                 if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) { | ||||
|                     boolean canContinue = finishEpoch(context); | ||||
|                     if (!canContinue) { | ||||
|                         break; | ||||
| @ -154,11 +155,10 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
|     } | ||||
| 
 | ||||
|     private void handleTraining(RunContext context) { | ||||
|         int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - context.epochElapsedSteps); | ||||
|         int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep); | ||||
|         SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); | ||||
| 
 | ||||
|         context.obs = subEpochReturn.getLastObs(); | ||||
|         context.epochElapsedSteps += subEpochReturn.getSteps(); | ||||
|         context.rewards += subEpochReturn.getReward(); | ||||
|         context.score = subEpochReturn.getScore(); | ||||
|     } | ||||
| @ -169,7 +169,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
| 
 | ||||
|         context.obs = initMdp.getLastObs(); | ||||
|         context.rewards = initMdp.getReward(); | ||||
|         context.epochElapsedSteps = initMdp.getSteps(); | ||||
| 
 | ||||
|         isEpochStarted = true; | ||||
|         preEpoch(); | ||||
| @ -180,9 +179,9 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
|     private boolean finishEpoch(RunContext context) { | ||||
|         isEpochStarted = false; | ||||
|         postEpoch(); | ||||
|         IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, context.epochElapsedSteps, context.score); | ||||
|         IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score); | ||||
| 
 | ||||
|         log.info("ThreadNum-" + threadNumber + " Epoch: " + getEpochCounter() + ", reward: " + context.rewards); | ||||
|         log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards); | ||||
| 
 | ||||
|         return listeners.notifyEpochTrainingResult(this, statEntry); | ||||
|     } | ||||
| @ -205,37 +204,30 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
|     protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); | ||||
| 
 | ||||
|     private Learning.InitMdp<Observation> refacInitMdp() { | ||||
|         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper(); | ||||
|         IHistoryProcessor hp = getHistoryProcessor(); | ||||
|         currentEpochStep = 0; | ||||
| 
 | ||||
|         Observation observation = mdp.reset(); | ||||
| 
 | ||||
|         int step = 0; | ||||
|         double reward = 0; | ||||
| 
 | ||||
|         boolean isHistoryProcessor = hp != null; | ||||
| 
 | ||||
|         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; | ||||
|         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; | ||||
| 
 | ||||
|         while (step < requiredFrame && !mdp.isDone()) { | ||||
| 
 | ||||
|             A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper(); | ||||
|         Observation observation = mdp.reset(); | ||||
| 
 | ||||
|         A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         while (observation.isSkipped() && !mdp.isDone()) { | ||||
|             StepReply<Observation> stepReply = mdp.step(action); | ||||
| 
 | ||||
|             reward += stepReply.getReward(); | ||||
|             observation = stepReply.getObservation(); | ||||
| 
 | ||||
|             step++; | ||||
| 
 | ||||
|             incrementStep(); | ||||
|         } | ||||
| 
 | ||||
|         return new Learning.InitMdp(step, observation, reward); | ||||
|         return new Learning.InitMdp(0, observation, reward); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public void incrementStep() { | ||||
|         ++stepCounter; | ||||
|         ++currentEpochStep; | ||||
|     } | ||||
| 
 | ||||
|     @AllArgsConstructor | ||||
| @ -260,7 +252,6 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne | ||||
|     private static class RunContext { | ||||
|         private Observation obs; | ||||
|         private double rewards; | ||||
|         private int epochElapsedSteps; | ||||
|         private double score; | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -20,19 +20,14 @@ import lombok.Getter; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.nn.gradient.Gradient; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||
| import org.deeplearning4j.rl4j.learning.sync.Transition; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.deeplearning4j.rl4j.policy.IPolicy; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.util.ArrayUtil; | ||||
| 
 | ||||
| import java.util.Stack; | ||||
| 
 | ||||
| @ -74,17 +69,18 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet> | ||||
|         IPolicy<O, Integer> policy = getPolicy(current); | ||||
| 
 | ||||
|         Integer action; | ||||
|         Integer lastAction = null; | ||||
|         Integer lastAction = getMdp().getActionSpace().noOp(); | ||||
|         IHistoryProcessor hp = getHistoryProcessor(); | ||||
|         int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1; | ||||
| 
 | ||||
|         double reward = 0; | ||||
|         double accuReward = 0; | ||||
|         int i = 0; | ||||
|         while (!getMdp().isDone() && i < nstep * skipFrame) { | ||||
|         int stepAtStart = getCurrentEpochStep(); | ||||
|         int lastStep = nstep * skipFrame + stepAtStart; | ||||
|         while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) { | ||||
| 
 | ||||
|             //if step of training, just repeat lastAction | ||||
|             if (i % skipFrame != 0 && lastAction != null) { | ||||
|             if (obs.isSkipped()) { | ||||
|                 action = lastAction; | ||||
|             } else { | ||||
|                 action = policy.nextAction(obs); | ||||
| @ -94,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet> | ||||
|             accuReward += stepReply.getReward() * getConf().getRewardFactor(); | ||||
| 
 | ||||
|             //if it's not a skipped frame, you can do a step of training | ||||
|             if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) { | ||||
|             if (!obs.isSkipped() || stepReply.isDone()) { | ||||
| 
 | ||||
|                 INDArray[] output = current.outputAll(obs.getData()); | ||||
|                 rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); | ||||
| @ -106,7 +102,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet> | ||||
| 
 | ||||
|             reward += stepReply.getReward(); | ||||
| 
 | ||||
|             i++; | ||||
|             incrementStep(); | ||||
|             lastAction = action; | ||||
|         } | ||||
| @ -114,7 +109,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet> | ||||
|         // A bit of a trick, usable because of how the stack is processed to initialize R. | ||||
|         // FIXME: The last element of minitrans is only used to seed the reward in calcGradient; observation, action and output are ignored. | ||||
| 
 | ||||
|         if (getMdp().isDone() && i < nstep * skipFrame) | ||||
|         if (getMdp().isDone() && getCurrentEpochStep() < lastStep) | ||||
|             rewards.add(new MiniTrans(obs.getData(), null, null, 0)); | ||||
|         else { | ||||
|             INDArray[] output = null; | ||||
| @ -127,9 +122,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet> | ||||
|             rewards.add(new MiniTrans(obs.getData(), null, output, maxQ)); | ||||
|         } | ||||
| 
 | ||||
|         getAsyncGlobal().enqueue(calcGradient(current, rewards), i); | ||||
|         getAsyncGlobal().enqueue(calcGradient(current, rewards), getCurrentEpochStep()); | ||||
| 
 | ||||
|         return new SubEpochReturn(i, obs, reward, current.getLatestScore()); | ||||
|         return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore()); | ||||
|     } | ||||
| 
 | ||||
|     public abstract Gradient[] calcGradient(NN nn, Stack<MiniTrans<Integer>> rewards); | ||||
|  | ||||
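
A worked example of the new loop bound in trainSubEpoch() (all numbers invented): with nstep = 5, skipFrame = 4, and an epoch already 12 steps in when the method is called:

    int nstep = 5, skipFrame = 4, stepAtStart = 12;
    int lastStep = nstep * skipFrame + stepAtStart;  // 32
    // The loop runs while getCurrentEpochStep() < 32, i.e. for at most
    // 20 more raw frames. With one training sample per skipFrame raw
    // frames, that is the same 5 training steps the old "i < nstep *
    // skipFrame" bound allowed, but expressed against the epoch-wide
    // counter instead of a local i.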
| @ -17,7 +17,6 @@ | ||||
| package org.deeplearning4j.rl4j.learning.sync; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.rl4j.learning.IEpochTrainer; | ||||
| import org.deeplearning4j.rl4j.learning.ILearning; | ||||
| @ -25,7 +24,6 @@ import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.learning.listener.*; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.IDataManager; | ||||
| 
 | ||||
| /** | ||||
|  | ||||
| @ -16,6 +16,7 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.rl4j.learning.sync; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import lombok.Value; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -34,7 +35,7 @@ import java.util.List; | ||||
|  * @author Alexandre Boulanger | ||||
|  * | ||||
|  */ | ||||
| @Value | ||||
| @Data | ||||
| public class Transition<A> { | ||||
| 
 | ||||
|     Observation observation; | ||||
| @ -43,12 +44,15 @@ public class Transition<A> { | ||||
|     boolean isTerminal; | ||||
|     INDArray nextObservation; | ||||
| 
 | ||||
|     public Transition(Observation observation, A action, double reward, boolean isTerminal, Observation nextObservation) { | ||||
|     public Transition(Observation observation, A action, double reward, boolean isTerminal) { | ||||
|         this.observation = observation; | ||||
|         this.action = action; | ||||
|         this.reward = reward; | ||||
|         this.isTerminal = isTerminal; | ||||
|         this.nextObservation = null; | ||||
|     } | ||||
| 
 | ||||
|     public void setNextObservation(Observation nextObservation) { | ||||
|         // To conserve memory, only the most recent frame of the next observation is kept (if history is used). | ||||
|         // The full nextObservation will be rebuilt from observation when needed. | ||||
|         long[] nextObservationShape = nextObservation.getData().shape().clone(); | ||||
|  | ||||
| @ -21,8 +21,7 @@ import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; | ||||
| import lombok.*; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.learning.EpochStepCounter; | ||||
| import org.deeplearning4j.rl4j.learning.sync.ExpReplay; | ||||
| import org.deeplearning4j.rl4j.learning.sync.IExpReplay; | ||||
| import org.deeplearning4j.rl4j.learning.sync.SyncLearning; | ||||
| @ -34,9 +33,8 @@ import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.IDataManager.StatEntry; | ||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| @ -49,7 +47,8 @@ import java.util.List; | ||||
|  */ | ||||
| @Slf4j | ||||
| public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>> | ||||
|                 extends SyncLearning<O, A, AS, IDQN> implements TargetQNetworkSource { | ||||
|                 extends SyncLearning<O, A, AS, IDQN> | ||||
|                 implements TargetQNetworkSource, EpochStepCounter { | ||||
| 
 | ||||
|     // FIXME Changed for refac | ||||
|     // @Getter | ||||
| @ -104,18 +103,22 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A | ||||
| 
 | ||||
|     protected abstract QLStepReturn<Observation> trainStep(Observation obs); | ||||
| 
 | ||||
|     @Getter | ||||
|     private int currentEpochStep = 0; | ||||
| 
 | ||||
|     protected StatEntry trainEpoch() { | ||||
|         resetNetworks(); | ||||
| 
 | ||||
|         InitMdp<Observation> initMdp = refacInitMdp(); | ||||
|         Observation obs = initMdp.getLastObs(); | ||||
| 
 | ||||
|         double reward = initMdp.getReward(); | ||||
|         int step = initMdp.getSteps(); | ||||
| 
 | ||||
|         Double startQ = Double.NaN; | ||||
|         double meanQ = 0; | ||||
|         int numQ = 0; | ||||
|         List<Double> scores = new ArrayList<>(); | ||||
|         while (step < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { | ||||
|         while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) { | ||||
| 
 | ||||
|             if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) { | ||||
|                 updateTargetNetwork(); | ||||
| @ -136,49 +139,53 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A | ||||
|             reward += stepR.getStepReply().getReward(); | ||||
|             obs = stepR.getStepReply().getObservation(); | ||||
|             incrementStep(); | ||||
|             step++; | ||||
|         } | ||||
| 
 | ||||
|         finishEpoch(obs); | ||||
| 
 | ||||
|         meanQ /= (numQ + 0.001); //avoid div zero | ||||
| 
 | ||||
| 
 | ||||
|         StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, step, scores, | ||||
|         StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores, | ||||
|                         getEgPolicy().getEpsilon(), startQ, meanQ); | ||||
| 
 | ||||
|         return statEntry; | ||||
|     } | ||||
| 
 | ||||
|     protected void finishEpoch(Observation observation) { | ||||
|         // Do Nothing | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void incrementStep() { | ||||
|         super.incrementStep(); | ||||
|         ++currentEpochStep; | ||||
|     } | ||||
| 
 | ||||
|     protected void resetNetworks() { | ||||
|         getQNetwork().reset(); | ||||
|         getTargetQNetwork().reset(); | ||||
|     } | ||||
| 
 | ||||
|     private InitMdp<Observation> refacInitMdp() { | ||||
|         getQNetwork().reset(); | ||||
|         getTargetQNetwork().reset(); | ||||
|         currentEpochStep = 0; | ||||
| 
 | ||||
|         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper(); | ||||
|         IHistoryProcessor hp = getHistoryProcessor(); | ||||
| 
 | ||||
|         Observation observation = mdp.reset(); | ||||
| 
 | ||||
|         int step = 0; | ||||
|         double reward = 0; | ||||
| 
 | ||||
|         boolean isHistoryProcessor = hp != null; | ||||
| 
 | ||||
|         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; | ||||
|         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; | ||||
| 
 | ||||
|         while (step < requiredFrame && !mdp.isDone()) { | ||||
| 
 | ||||
|             A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper(); | ||||
|         Observation observation = mdp.reset(); | ||||
| 
 | ||||
|         A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         while (observation.isSkipped() && !mdp.isDone()) { | ||||
|             StepReply<Observation> stepReply = mdp.step(action); | ||||
| 
 | ||||
|             reward += stepReply.getReward(); | ||||
|             observation = stepReply.getObservation(); | ||||
| 
 | ||||
|             step++; | ||||
| 
 | ||||
|             incrementStep(); | ||||
|         } | ||||
| 
 | ||||
|         return new InitMdp(step, observation, reward); | ||||
|         return new InitMdp(0, observation, reward); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -20,10 +20,13 @@ import lombok.AccessLevel; | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.learning.sync.Transition; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.*; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.dqn.IDQN; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| @ -36,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.dataset.api.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.util.ArrayUtil; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| 
 | ||||
| @ -68,6 +70,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
|     private int lastAction; | ||||
|     private double accuReward = 0; | ||||
| 
 | ||||
|     private Transition pendingTransition; | ||||
| 
 | ||||
|     ITDTargetAlgorithm tdTargetAlgorithm; | ||||
| 
 | ||||
|     protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() { | ||||
| @ -83,7 +87,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
|                              int epsilonNbStep, Random random) { | ||||
|         super(conf); | ||||
|         this.configuration = conf; | ||||
|         this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, this); | ||||
|         this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this); | ||||
|         qNetwork = dqn; | ||||
|         targetQNetwork = dqn.clone(); | ||||
|         policy = new DQNPolicy(getQNetwork()); | ||||
| @ -108,8 +112,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
|     } | ||||
| 
 | ||||
|     public void preEpoch() { | ||||
|         lastAction = 0; | ||||
|         lastAction = mdp.getActionSpace().noOp(); | ||||
|         accuReward = 0; | ||||
|         pendingTransition = null; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setHistoryProcessor(IHistoryProcessor historyProcessor) { | ||||
|         super.setHistoryProcessor(historyProcessor); | ||||
|         mdp.setHistoryProcessor(historyProcessor); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
| @ -120,9 +131,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
|     protected QLStepReturn<Observation> trainStep(Observation obs) { | ||||
| 
 | ||||
|         Integer action; | ||||
| 
 | ||||
|         boolean isHistoryProcessor = getHistoryProcessor() != null; | ||||
| 
 | ||||
| 
 | ||||
|         int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; | ||||
|         int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1; | ||||
|         int updateStart = getConfiguration().getUpdateStart() | ||||
| @ -131,7 +141,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
|         Double maxQ = Double.NaN; //ignore if Nan for stats | ||||
| 
 | ||||
|         //if step of training, just repeat lastAction | ||||
|         if (getStepCounter() % skipFrame != 0) { | ||||
|         if (obs.isSkipped()) { | ||||
|             action = lastAction; | ||||
|         } else { | ||||
|             INDArray qs = getQNetwork().output(obs); | ||||
| @ -145,22 +155,25 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
| 
 | ||||
|         StepReply<Observation> stepReply = mdp.step(action); | ||||
| 
 | ||||
|         Observation nextObservation = stepReply.getObservation(); | ||||
| 
 | ||||
|         accuReward += stepReply.getReward() * configuration.getRewardFactor(); | ||||
| 
 | ||||
|         //if it's not a skipped frame, you can do a step of training | ||||
|         if (getStepCounter() % skipFrame == 0 || stepReply.isDone()) { | ||||
|         if (!obs.isSkipped() || stepReply.isDone()) { | ||||
| 
 | ||||
|             Transition<Integer> trans = new Transition(obs, action, accuReward, stepReply.isDone(), nextObservation); | ||||
|             getExpReplay().store(trans); | ||||
|             // Add experience | ||||
|             if(pendingTransition != null) { | ||||
|                 pendingTransition.setNextObservation(obs); | ||||
|                 getExpReplay().store(pendingTransition); | ||||
|             } | ||||
|             pendingTransition = new Transition(obs, action, accuReward, stepReply.isDone()); | ||||
|             accuReward = 0; | ||||
| 
 | ||||
|             // Update NN | ||||
|             // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"? | ||||
|             if (getStepCounter() > updateStart) { | ||||
|                 DataSet targets = setTarget(getExpReplay().getBatch()); | ||||
|                 getQNetwork().fit(targets.getFeatures(), targets.getLabels()); | ||||
|             } | ||||
| 
 | ||||
|             accuReward = 0; | ||||
|         } | ||||
| 
 | ||||
|         return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply); | ||||
| @ -172,4 +185,12 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O | ||||
| 
 | ||||
|         return tdTargetAlgorithm.computeTDTargets(transitions); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected void finishEpoch(Observation observation) { | ||||
|         if(pendingTransition != null) { | ||||
|             pendingTransition.setNextObservation(observation); | ||||
|             getExpReplay().store(pendingTransition); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
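
A hypothetical trace of the pending-transition flow above, with skipFrame = 2 (frame numbers and rewards invented):

    // frame 1, obs1 skipped:     action = lastAction; accuReward += r1
    // frame 2, obs2 not skipped: pick action; accuReward += r2;
    //                            no pending yet, so pendingTransition = T(obs2, action, r1+r2)
    // frame 3, obs3 skipped:     repeat action; accuReward += r3
    // frame 4, obs4 not skipped: accuReward += r4;
    //                            T(obs2).setNextObservation(obs4); store T(obs2);
    //                            pendingTransition = T(obs4, action, r3+r4)
    // epoch ends with obs5:      finishEpoch(obs5) completes and stores T(obs4)

Rewards of skipped frames are folded into the next stored transition, and each transition's next observation is the following non-skipped observation rather than the raw next frame.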
| @ -92,6 +92,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | ||||
|         theta = 0.1 * rnd.nextDouble() - 0.05; | ||||
|         thetaDot = 0.1 * rnd.nextDouble() - 0.05; | ||||
|         stepsBeyondDone = null; | ||||
|         done = false; | ||||
| 
 | ||||
|         return new State(new double[] { x, xDot, theta, thetaDot }); | ||||
|     } | ||||
| @ -126,7 +127,7 @@ public class CartpoleNative implements MDP<CartpoleNative.State, Integer, Discre | ||||
|                 break; | ||||
|         } | ||||
| 
 | ||||
|         boolean done =  x < -xThreshold || x > xThreshold | ||||
|         done |=  x < -xThreshold || x > xThreshold | ||||
|                 || theta < -thetaThresholdRadians || theta > thetaThresholdRadians; | ||||
| 
 | ||||
|         double reward; | ||||
|  | ||||
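
The termination flag in CartpoleNative is promoted from a step-local variable to a latching field: reset() clears it and step() can only turn it on. A sketch of the idiom (a fragment; the threshold condition is abbreviated) to show that isDone() keeps returning true after termination even if a later step's state drifts back inside the thresholds:

    private boolean done;

    public boolean isDone() { return done; }

    public void reset() {
        done = false;  // new episode
        // ... re-randomize x, xDot, theta, thetaDot ...
    }

    public void step(int action) {
        // ... integrate dynamics ...
        done |= isOutOfBounds();  // hypothetical helper wrapping the x/theta threshold checks
    }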
| @ -16,6 +16,8 @@ | ||||
| 
 | ||||
| package org.deeplearning4j.rl4j.observation; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.api.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -28,6 +30,9 @@ public class Observation { | ||||
| 
 | ||||
|     private final DataSet data; | ||||
| 
 | ||||
|     @Getter @Setter | ||||
|     private boolean skipped; | ||||
| 
 | ||||
|     public Observation(INDArray[] data) { | ||||
|         this(data, false); | ||||
|     } | ||||
|  | ||||
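
A sketch of how callers consume the new flag (this is the pattern used throughout this commit):

    Observation obs = stepReply.getObservation();
    if (obs.isSkipped()) {
        // Frame belongs to a skip interval: repeat the last action and
        // accumulate reward, but do not select an action, train, or
        // store experience for it.
    }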
| @ -18,7 +18,8 @@ package org.deeplearning4j.rl4j.policy; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.rl4j.learning.StepCountable; | ||||
| import org.deeplearning4j.rl4j.learning.IEpochTrainer; | ||||
| import org.deeplearning4j.rl4j.learning.ILearning; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| @ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> { | ||||
|     final private int epsilonNbStep; | ||||
|     final private Random rnd; | ||||
|     final private float minEpsilon; | ||||
|     final private StepCountable learning; | ||||
|     final private IEpochTrainer learning; | ||||
| 
 | ||||
|     public NeuralNet getNeuralNet() { | ||||
|         return policy.getNeuralNet(); | ||||
|  | ||||
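
For context, the step counter is what lets the policy anneal epsilon over training. The exact schedule is not shown in this diff; a plausible linear-annealing sketch using only the fields above (formula assumed, not taken from the source):

    float getEpsilon() {
        int step = learning.getStepCounter();
        if (step >= epsilonNbStep) {
            return minEpsilon;
        }
        return 1f - step * (1f - minEpsilon) / epsilonNbStep;  // linear decay from 1.0 to minEpsilon
    }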
| @ -19,20 +19,15 @@ package org.deeplearning4j.rl4j.policy; | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.rl4j.learning.EpochStepCounter; | ||||
| import org.deeplearning4j.rl4j.learning.HistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.Learning; | ||||
| import org.deeplearning4j.rl4j.learning.StepCountable; | ||||
| import org.deeplearning4j.rl4j.learning.sync.Transition; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.util.ArrayUtil; | ||||
| 
 | ||||
| /** | ||||
|  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. | ||||
| @ -57,24 +52,22 @@ public abstract class Policy<O, A> implements IPolicy<O, A> { | ||||
| 
 | ||||
|     @Override | ||||
|     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) { | ||||
|         RefacStepCountable stepCountable = new RefacStepCountable(); | ||||
|         LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, stepCountable); | ||||
|         resetNetworks(); | ||||
| 
 | ||||
|         boolean isHistoryProcessor = hp != null; | ||||
|         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; | ||||
|         RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter(); | ||||
|         LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter); | ||||
| 
 | ||||
|         Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp); | ||||
|         Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter); | ||||
|         Observation obs = initMdp.getLastObs(); | ||||
| 
 | ||||
|         double reward = initMdp.getReward(); | ||||
| 
 | ||||
|         A lastAction = mdpWrapper.getActionSpace().noOp(); | ||||
|         A action; | ||||
|         stepCountable.setStepCounter(initMdp.getSteps()); | ||||
| 
 | ||||
|         while (!mdpWrapper.isDone()) { | ||||
| 
 | ||||
|             if (stepCountable.getStepCounter() % skipFrame != 0) { | ||||
|             if (obs.isSkipped()) { | ||||
|                 action = lastAction; | ||||
|             } else { | ||||
|                 action = nextAction(obs); | ||||
| @ -86,52 +79,46 @@ public abstract class Policy<O, A> implements IPolicy<O, A> { | ||||
|             reward += stepReply.getReward(); | ||||
| 
 | ||||
|             obs = stepReply.getObservation(); | ||||
|             stepCountable.increment(); | ||||
|             epochStepCounter.incrementEpochStep(); | ||||
|         } | ||||
| 
 | ||||
|         return reward; | ||||
|     } | ||||
| 
 | ||||
|     private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) { | ||||
|     protected void resetNetworks() { | ||||
|         getNeuralNet().reset(); | ||||
|         Observation observation = mdpWrapper.reset(); | ||||
|     } | ||||
| 
 | ||||
|     private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { | ||||
|         epochStepCounter.setCurrentEpochStep(0); | ||||
| 
 | ||||
|         int step = 0; | ||||
|         double reward = 0; | ||||
| 
 | ||||
|         boolean isHistoryProcessor = hp != null; | ||||
|         Observation observation = mdpWrapper.reset(); | ||||
| 
 | ||||
|         int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1; | ||||
|         int requiredFrame = isHistoryProcessor ? skipFrame * (hp.getConf().getHistoryLength() - 1) : 0; | ||||
| 
 | ||||
|         while (step < requiredFrame && !mdpWrapper.isDone()) { | ||||
| 
 | ||||
|             A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         A action = mdpWrapper.getActionSpace().noOp(); //by convention should be the NO_OP | ||||
|         while (observation.isSkipped() && !mdpWrapper.isDone()) { | ||||
| 
 | ||||
|             StepReply<Observation> stepReply = mdpWrapper.step(action); | ||||
| 
 | ||||
|             reward += stepReply.getReward(); | ||||
|             observation = stepReply.getObservation(); | ||||
| 
 | ||||
|             step++; | ||||
| 
 | ||||
|             epochStepCounter.incrementEpochStep(); | ||||
|         } | ||||
| 
 | ||||
|         return new Learning.InitMdp(step, observation, reward); | ||||
|         return new Learning.InitMdp(0, observation, reward); | ||||
|     } | ||||
| 
 | ||||
|     private class RefacStepCountable implements StepCountable { | ||||
|     public class RefacEpochStepCounter implements EpochStepCounter { | ||||
| 
 | ||||
|         @Getter | ||||
|         @Setter | ||||
|         private int stepCounter = 0; | ||||
|         private int currentEpochStep = 0; | ||||
| 
 | ||||
|         public void increment() { | ||||
|             ++stepCounter; | ||||
|         public void incrementEpochStep() { | ||||
|             ++currentEpochStep; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public int getStepCounter() { | ||||
|             return 0; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
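
A hypothetical harness showing the refactored play() from the outside (names invented; mdp and hp construction elided):

    static <O, A, AS extends ActionSpace<A>> void evaluate(
            Policy<O, A> policy, MDP<O, A, AS> mdp, IHistoryProcessor hp, int episodes) {
        for (int i = 0; i < episodes; i++) {
            // play() resets the networks, no-ops through skipped warm-up
            // frames, then runs the episode to completion.
            double reward = policy.play(mdp, hp);
            System.out.println("episode " + i + " reward: " + reward);
        }
    }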
| @ -1,10 +1,11 @@ | ||||
| package org.deeplearning4j.rl4j.util; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.deeplearning4j.gym.StepReply; | ||||
| import org.deeplearning4j.rl4j.learning.EpochStepCounter; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.ILearning; | ||||
| import org.deeplearning4j.rl4j.learning.StepCountable; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.deeplearning4j.rl4j.space.ActionSpace; | ||||
| @ -19,50 +20,20 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob | ||||
|     private final MDP<O, A, AS> wrappedMDP; | ||||
|     @Getter | ||||
|     private final WrapperObservationSpace observationSpace; | ||||
|     private final ILearning learning; | ||||
| 
 | ||||
|     @Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC) | ||||
|     private IHistoryProcessor historyProcessor; | ||||
|     private final StepCountable stepCountable; | ||||
|     private int skipFrame; | ||||
| 
 | ||||
|     private int step = 0; | ||||
|     private final EpochStepCounter epochStepCounter; | ||||
| 
 | ||||
|     public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning) { | ||||
|         this(wrappedMDP, learning, null, null); | ||||
|     } | ||||
|     private int skipFrame = 1; | ||||
|     private int requiredFrame = 0; | ||||
| 
 | ||||
|     public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, StepCountable stepCountable) { | ||||
|         this(wrappedMDP, null, historyProcessor, stepCountable); | ||||
|     } | ||||
| 
 | ||||
|     private LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, ILearning learning, IHistoryProcessor historyProcessor, StepCountable stepCountable) { | ||||
|     public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { | ||||
|         this.wrappedMDP = wrappedMDP; | ||||
|         this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); | ||||
|         this.learning = learning; | ||||
|         this.historyProcessor = historyProcessor; | ||||
|         this.stepCountable = stepCountable; | ||||
|     } | ||||
| 
 | ||||
|     private IHistoryProcessor getHistoryProcessor() { | ||||
|         if(historyProcessor != null) { | ||||
|             return historyProcessor; | ||||
|         } | ||||
| 
 | ||||
|         if (learning != null) { | ||||
|             return learning.getHistoryProcessor(); | ||||
|         } | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     public void setHistoryProcessor(IHistoryProcessor historyProcessor) { | ||||
|         this.historyProcessor = historyProcessor; | ||||
|     } | ||||
| 
 | ||||
|     private int getStep() { | ||||
|         if(stepCountable != null) { | ||||
|             return stepCountable.getStepCounter(); | ||||
|         } | ||||
| 
 | ||||
|         return learning.getStepCounter(); | ||||
|         this.epochStepCounter = epochStepCounter; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
| @ -83,9 +54,12 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob | ||||
| 
 | ||||
|         if(historyProcessor != null) { | ||||
|             skipFrame = historyProcessor.getConf().getSkipFrame(); | ||||
|             requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1); | ||||
| 
 | ||||
|             historyProcessor.add(rawObservation); | ||||
|         } | ||||
|         step = 0; | ||||
| 
 | ||||
|         observation.setSkipped(skipFrame != 0); | ||||
| 
 | ||||
|         return observation; | ||||
|     } | ||||
| @ -97,21 +71,18 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob | ||||
|         StepReply<O> rawStepReply = wrappedMDP.step(a); | ||||
|         INDArray rawObservation = getInput(rawStepReply.getObservation()); | ||||
| 
 | ||||
|         ++step; | ||||
|         int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; | ||||
| 
 | ||||
|         int requiredFrame = 0; | ||||
|         if(historyProcessor != null) { | ||||
|             historyProcessor.record(rawObservation); | ||||
| 
 | ||||
|             requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1); | ||||
|             if ((getStep() % skipFrame == 0 && step >= requiredFrame) | ||||
|             || (step % skipFrame == 0 && step < requiredFrame )){ | ||||
|             if (stepOfObservation % skipFrame == 0) { | ||||
|                 historyProcessor.add(rawObservation); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         Observation observation; | ||||
|         if(historyProcessor != null && step >= requiredFrame) { | ||||
|         if(historyProcessor != null && stepOfObservation >= requiredFrame) { | ||||
|             observation = new Observation(historyProcessor.getHistory(), true); | ||||
|             observation.getData().muli(1.0 / historyProcessor.getScale()); | ||||
|         } | ||||
| @ -119,6 +90,10 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob | ||||
|             observation = new Observation(new INDArray[] { rawObservation }, false); | ||||
|         } | ||||
| 
 | ||||
|         if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) { | ||||
|             observation.setSkipped(true); | ||||
|         } | ||||
| 
 | ||||
|         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); | ||||
|     } | ||||
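
A worked example of the skip bookkeeping above (numbers invented): with skipFrame = 4 and historyLength = 4, requiredFrame = 4 * (4 - 1) = 12, so:

    // stepOfObservation 1..3, 5..7, 9..11 -> % 4 != 0           -> skipped
    // stepOfObservation 4, 8              -> % 4 == 0 but < 12  -> skipped (history still filling)
    // stepOfObservation 12                -> % 4 == 0 and >= 12 -> first real observation
    //
    // By then the history holds 4 frames (the reset frame plus steps 4, 8
    // and 12), exactly historyLength, so the observation can be assembled
    // from historyProcessor.getHistory().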
| 
 | ||||
| @ -134,7 +109,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob | ||||
| 
 | ||||
|     @Override | ||||
|     public MDP<Observation, A, AS> newInstance() { | ||||
|         return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), learning); | ||||
|         return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter); | ||||
|     } | ||||
| 
 | ||||
|     private INDArray getInput(O obs) { | ||||
|  | ||||
| @ -32,7 +32,7 @@ public class AsyncThreadDiscreteTest { | ||||
|         MockMDP mdpMock = new MockMDP(observationSpace); | ||||
|         TrainingListenerList listeners = new TrainingListenerList(); | ||||
|         MockPolicy policyMock = new MockPolicy(); | ||||
|         MockAsyncConfiguration config = new MockAsyncConfiguration(5, 16, 0, 0, 2, 5,0, 0, 0, 0); | ||||
|         MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); | ||||
|         TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); | ||||
| 
 | ||||
|         // Act | ||||
| @ -41,8 +41,8 @@ public class AsyncThreadDiscreteTest { | ||||
|         // Assert | ||||
|         assertEquals(2, sut.trainSubEpochResults.size()); | ||||
|         double[][] expectedLastObservations = new double[][] { | ||||
|             new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|             new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, | ||||
|             new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|             new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, | ||||
|         }; | ||||
|         double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 }; | ||||
|         for(int i = 0; i < 2; ++i) { | ||||
| @ -60,7 +60,7 @@ public class AsyncThreadDiscreteTest { | ||||
|         assertEquals(2, asyncGlobalMock.enqueueCallCount); | ||||
| 
 | ||||
|         // HistoryProcessor | ||||
|         double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; | ||||
|         double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; | ||||
|         assertEquals(expectedAddValues.length, hpMock.addCalls.size()); | ||||
|         for(int i = 0; i < expectedAddValues.length; ++i) { | ||||
|             assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); | ||||
| @ -75,9 +75,9 @@ public class AsyncThreadDiscreteTest { | ||||
|         // Policy | ||||
|         double[][] expectedPolicyInputs = new double[][] { | ||||
|                 new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|                 new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|                 new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|         }; | ||||
|         assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size()); | ||||
|         for(int i = 0; i < expectedPolicyInputs.length; ++i) { | ||||
| @ -93,11 +93,11 @@ public class AsyncThreadDiscreteTest { | ||||
|         assertEquals(2, nnMock.copyCallCount); | ||||
|         double[][] expectedNNInputs = new double[][] { | ||||
|                 new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: This one comes from the computation of output of the last minitrans | ||||
|                 new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|                 new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                 new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, // FIXME: This one comes from the computation of output of the last minitrans | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: This one comes from the computation of output of the last minitrans | ||||
|                 new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|                 new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|                 new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: This one comes from the computation of output of the last minitrans | ||||
|         }; | ||||
|         assertEquals(expectedNNInputs.length, nnMock.outputAllInputs.size()); | ||||
|         for(int i = 0; i < expectedNNInputs.length; ++i) { | ||||
| @ -113,13 +113,13 @@ public class AsyncThreadDiscreteTest { | ||||
|         double[][][] expectedMinitransObs = new double[][][] { | ||||
|             new double[][] { | ||||
|                     new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, | ||||
|                     new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                     new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, // FIXME: The last minitrans contains the next observation | ||||
|                     new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                     new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, // FIXME: The last minitrans contains the next observation | ||||
|             }, | ||||
|             new double[][] { | ||||
|                     new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|                     new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                     new double[] { 8.0, 9.0, 11.0, 13.0, 15 }, // FIXME: The last minitrans contains the next observation | ||||
|                     new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|                     new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|                     new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, // FIXME: The last minitrans contains the next observation | ||||
|             } | ||||
|         }; | ||||
|         double[] expectedOutputs = new double[] { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0 }; | ||||
|  | ||||
| @ -5,15 +5,12 @@ import lombok.Getter; | ||||
| import org.deeplearning4j.rl4j.learning.IHistoryProcessor; | ||||
| import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| import org.deeplearning4j.rl4j.network.NeuralNet; | ||||
| import org.deeplearning4j.rl4j.observation.Observation; | ||||
| import org.deeplearning4j.rl4j.policy.Policy; | ||||
| import org.deeplearning4j.rl4j.space.DiscreteSpace; | ||||
| import org.deeplearning4j.rl4j.space.Encodable; | ||||
| import org.deeplearning4j.rl4j.support.*; | ||||
| import org.deeplearning4j.rl4j.util.IDataManager; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| @ -91,7 +88,7 @@ public class AsyncThreadTest { | ||||
| 
 | ||||
|         // Assert | ||||
|         assertEquals(numberOfEpochs, context.listener.statEntries.size()); | ||||
|         int[] expectedStepCounter = new int[] { 2, 4, 6, 8, 10 }; | ||||
|         int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 }; | ||||
|         double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init | ||||
|             + 1.0; // Reward from trainSubEpoch() | ||||
|         for(int i = 0; i < numberOfEpochs; ++i) { | ||||
| @ -114,7 +111,7 @@ public class AsyncThreadTest { | ||||
|         // Assert | ||||
|         assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size()); | ||||
|         double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }; | ||||
|         for(int i = 0; i < context.sut.getEpochCounter(); ++i) { | ||||
|         for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) { | ||||
|             MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i); | ||||
|             assertEquals(2, params.nstep); | ||||
|             assertEquals(expectedObservation.length, params.obs.getData().shape()[1]); | ||||
| @ -199,7 +196,9 @@ public class AsyncThreadTest { | ||||
|         protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) { | ||||
|             asyncGlobal.increaseCurrentLoop(); | ||||
|             trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep)); | ||||
|             setStepCounter(getStepCounter() + nstep); | ||||
|             for(int i = 0; i < nstep; ++i) { | ||||
|                 incrementStep(); | ||||
|             } | ||||
|             return new SubEpochReturn(nstep, null, 1.0, 1.0); | ||||
|         } | ||||
| 
 | ||||
|  | ||||
| @ -18,8 +18,8 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 123, 234, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 123, 234, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition); | ||||
|         List<Transition<Integer>> results = sut.getBatch(1); | ||||
| 
 | ||||
| @ -36,12 +36,12 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
| @ -78,12 +78,12 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
| @ -100,12 +100,12 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
| @ -131,16 +131,16 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 7, 8, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 9, 10, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 7, 8, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 9, 10, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
| @ -168,16 +168,16 @@ public class ExpReplayTest { | ||||
|         ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition<Integer> transition1 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition4 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 7, 8, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition5 = new Transition<Integer>(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 9, 10, false, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 1, 2, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 3, 4, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 5, 6, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 7, 8, new Observation(Nd4j.create(1))); | ||||
|         Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), | ||||
|                 9, 10, new Observation(Nd4j.create(1))); | ||||
|         sut.store(transition1); | ||||
|         sut.store(transition2); | ||||
|         sut.store(transition3); | ||||
| @ -196,6 +196,12 @@ public class ExpReplayTest { | ||||
| 
 | ||||
|         assertEquals(5, (int)results.get(2).getAction()); | ||||
|         assertEquals(6, (int)results.get(2).getReward()); | ||||
|     } | ||||
| 
 | ||||
|     private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, Observation nextObservation) { | ||||
|         Transition<Integer> result = new Transition<Integer>(observation, action, reward, false); | ||||
|         result.setNextObservation(nextObservation); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| } | ||||
|  | ||||
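The ExpReplayTest changes above track an API change to Transition: the constructor no longer takes the next observation, which is now attached afterward with setNextObservation(). A hedged sketch of the pattern the buildTransition() helper wraps (the action and reward values are illustrative):

    // Assumes the Observation/Transition constructors used by these tests.
    Observation observation = new Observation(new INDArray[] { Nd4j.create(1) });
    Observation nextObservation = new Observation(Nd4j.create(1));

    // The constructor now stops at the terminal flag...
    Transition<Integer> transition = new Transition<Integer>(observation, 123, 234.0, false);
    // ...and the next observation is attached through a setter.
    transition.setNextObservation(nextObservation);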
| @ -1,5 +1,6 @@ | ||||
| package org.deeplearning4j.rl4j.learning.sync; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; | ||||
| import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; | ||||
| import org.deeplearning4j.rl4j.mdp.MDP; | ||||
| @ -88,12 +89,15 @@ public class SyncLearningTest { | ||||
| 
 | ||||
|         private final LConfiguration conf; | ||||
| 
 | ||||
|         @Getter | ||||
|         private int currentEpochStep = 0; | ||||
| 
 | ||||
|         public MockSyncLearning(LConfiguration conf) { | ||||
|             this.conf = conf; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         protected void preEpoch() { } | ||||
|         protected void preEpoch() { currentEpochStep = 0; } | ||||
| 
 | ||||
|         @Override | ||||
|         protected void postEpoch() { } | ||||
| @ -101,7 +105,7 @@ public class SyncLearningTest { | ||||
|         @Override | ||||
|         protected IDataManager.StatEntry trainEpoch() { | ||||
|             setStepCounter(getStepCounter() + 1); | ||||
|             return new MockStatEntry(getEpochCounter(), getStepCounter(), 1.0); | ||||
|             return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0); | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|  | ||||
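In MockSyncLearning above, the getCurrentEpochStep() accessor is generated by Lombok: annotating the currentEpochStep field with @Getter produces exactly that method, which is how the mock satisfies the new interface without a hand-written getter. A minimal sketch (the class name is illustrative):

    import lombok.Getter;

    class EpochStepExample {
        // Lombok generates public int getCurrentEpochStep() from this field.
        @Getter
        private int currentEpochStep = 0;
    }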
| @ -21,7 +21,7 @@ public class TransitionTest { | ||||
|         Observation nextObservation = buildObservation(nextObs); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition transition = new Transition(observation, 123, 234.0, false, nextObservation); | ||||
|         Transition transition = buildTransition(observation, 123, 234.0, nextObservation); | ||||
| 
 | ||||
|         // Assert | ||||
|         double[][] expectedObservation = new double[][] { obs }; | ||||
| @ -52,7 +52,7 @@ public class TransitionTest { | ||||
|         Observation nextObservation = buildObservation(nextObs); | ||||
| 
 | ||||
|         // Act | ||||
|         Transition transition = new Transition(observation, 123, 234.0, false, nextObservation); | ||||
|         Transition transition = buildTransition(observation, 123, 234.0, nextObservation); | ||||
| 
 | ||||
|         // Assert | ||||
|         assertExpected(obs, transition.getObservation().getData()); | ||||
| @ -71,12 +71,12 @@ public class TransitionTest { | ||||
|         double[] obs1 = new double[] { 0.0, 1.0, 2.0 }; | ||||
|         Observation observation1 = buildObservation(obs1); | ||||
|         Observation nextObservation1 = buildObservation(new double[] { 100.0, 101.0, 102.0 }); | ||||
|         transitions.add(new Transition(observation1,0, 0.0, false, nextObservation1)); | ||||
|         transitions.add(buildTransition(observation1,0, 0.0, nextObservation1)); | ||||
| 
 | ||||
|         double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; | ||||
|         Observation observation2 = buildObservation(obs2); | ||||
|         Observation nextObservation2 = buildObservation(new double[] { 110.0, 111.0, 112.0 }); | ||||
|         transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); | ||||
|         transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); | ||||
| 
 | ||||
|         // Act | ||||
|         INDArray result = Transition.buildStackedObservations(transitions); | ||||
| @ -101,7 +101,7 @@ public class TransitionTest { | ||||
|         double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; | ||||
|         Observation nextObservation1 = buildNextObservation(obs1, nextObs1); | ||||
| 
 | ||||
|         transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); | ||||
|         transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); | ||||
| 
 | ||||
|         double[][] obs2 = new double[][] { | ||||
|                 { 10.0, 11.0, 12.0 }, | ||||
| @ -112,7 +112,7 @@ public class TransitionTest { | ||||
| 
 | ||||
|         double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; | ||||
|         Observation nextObservation2 = buildNextObservation(obs2, nextObs2); | ||||
|         transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); | ||||
|         transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); | ||||
| 
 | ||||
|         // Act | ||||
|         INDArray result = Transition.buildStackedObservations(transitions); | ||||
| @ -131,13 +131,13 @@ public class TransitionTest { | ||||
|         double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; | ||||
|         Observation observation1 = buildObservation(obs1); | ||||
|         Observation nextObservation1 = buildObservation(nextObs1); | ||||
|         transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); | ||||
|         transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); | ||||
| 
 | ||||
|         double[] obs2 = new double[] { 10.0, 11.0, 12.0 }; | ||||
|         double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; | ||||
|         Observation observation2 = buildObservation(obs2); | ||||
|         Observation nextObservation2 = buildObservation(nextObs2); | ||||
|         transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); | ||||
|         transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); | ||||
| 
 | ||||
|         // Act | ||||
|         INDArray result = Transition.buildStackedNextObservations(transitions); | ||||
| @ -162,7 +162,7 @@ public class TransitionTest { | ||||
|         double[] nextObs1 = new double[] { 100.0, 101.0, 102.0 }; | ||||
|         Observation nextObservation1 = buildNextObservation(obs1, nextObs1); | ||||
| 
 | ||||
|         transitions.add(new Transition(observation1, 0, 0.0, false, nextObservation1)); | ||||
|         transitions.add(buildTransition(observation1, 0, 0.0, nextObservation1)); | ||||
| 
 | ||||
|         double[][] obs2 = new double[][] { | ||||
|                 { 10.0, 11.0, 12.0 }, | ||||
| @ -174,7 +174,7 @@ public class TransitionTest { | ||||
|         double[] nextObs2 = new double[] { 110.0, 111.0, 112.0 }; | ||||
|         Observation nextObservation2 = buildNextObservation(obs2, nextObs2); | ||||
| 
 | ||||
|         transitions.add(new Transition(observation2, 0, 0.0, false, nextObservation2)); | ||||
|         transitions.add(buildTransition(observation2, 0, 0.0, nextObservation2)); | ||||
| 
 | ||||
|         // Act | ||||
|         INDArray result = Transition.buildStackedNextObservations(transitions); | ||||
| @ -207,7 +207,13 @@ public class TransitionTest { | ||||
|                 Nd4j.create(obs[1]).reshape(1, 3), | ||||
|         }; | ||||
|         return new Observation(nextHistory); | ||||
|     } | ||||
| 
 | ||||
|     private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) { | ||||
|         Transition result = new Transition(observation, action, reward, false); | ||||
|         result.setNextObservation(nextObservation); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     private void assertExpected(double[] expected, INDArray actual) { | ||||
|  | ||||
| @ -40,8 +40,10 @@ public class QLearningDiscreteTest { | ||||
|             new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); | ||||
|         MockMDP mdp = new MockMDP(observationSpace, random); | ||||
| 
 | ||||
|         int initStepCount = 8; | ||||
| 
 | ||||
|         QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, | ||||
|                 0, 1.0, 0, 0, 0, 0, true); | ||||
|                 initStepCount, 1.0, 0, 0, 0, 0, true); | ||||
|         MockDataManager dataManager = new MockDataManager(false); | ||||
|         MockExpReplay expReplay = new MockExpReplay(); | ||||
|         TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); | ||||
| @ -60,7 +62,7 @@ public class QLearningDiscreteTest { | ||||
|         for(int i = 0; i < expectedRecords.length; ++i) { | ||||
|             assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); | ||||
|         } | ||||
|         double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 }; | ||||
|         double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 }; | ||||
|         assertEquals(expectedAdds.length, hp.addCalls.size()); | ||||
|         for(int i = 0; i < expectedAdds.length; ++i) { | ||||
|             assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001); | ||||
| @ -75,19 +77,19 @@ public class QLearningDiscreteTest { | ||||
|         assertEquals(14, dqn.outputParams.size()); | ||||
|         double[][] expectedDQNOutput = new double[][] { | ||||
|                 new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|                 new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                 new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                 new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, | ||||
|                 new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, | ||||
|                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, | ||||
|                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, | ||||
|                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, | ||||
|                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, | ||||
|                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, | ||||
|                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|                 new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|                 new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|                 new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, | ||||
|                 new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, | ||||
|                 new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, | ||||
|                 new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, | ||||
|                 new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, | ||||
|                 new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, | ||||
|                 new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, | ||||
|                 new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, | ||||
|         }; | ||||
|         for(int i = 0; i < expectedDQNOutput.length; ++i) { | ||||
|             INDArray outputParam = dqn.outputParams.get(i); | ||||
| @ -105,19 +107,20 @@ public class QLearningDiscreteTest { | ||||
|         assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); | ||||
| 
 | ||||
|         // ExpReplay calls | ||||
|         double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 }; | ||||
|         double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; | ||||
|         int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; | ||||
|         double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 9.0, 11.0, 13.0, 15.0 }; | ||||
|         double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; | ||||
|         double[][] expectedTrObservations = new double[][] { | ||||
|                 new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 9.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 9.0, 11.0 }, | ||||
|                 new double[] { 6.0, 8.0, 9.0, 11.0, 13.0 }, | ||||
|                 new double[] { 8.0, 9.0, 11.0, 13.0, 15.0 }, | ||||
|                 new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 }, | ||||
|                 new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 }, | ||||
|                 new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 }, | ||||
|                 new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, | ||||
|                 new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, | ||||
|                 new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, | ||||
|                 new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, | ||||
|                 new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, | ||||
|                 new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, | ||||
|                 new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, | ||||
|         }; | ||||
|         assertEquals(expectedTrObservations.length, expReplay.transitions.size()); | ||||
|         for(int i = 0; i < expectedTrRewards.length; ++i) { | ||||
|             Transition tr = expReplay.transitions.get(i); | ||||
|             assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); | ||||
| @ -129,7 +132,7 @@ public class QLearningDiscreteTest { | ||||
|         } | ||||
| 
 | ||||
|         // trainEpoch result | ||||
|         assertEquals(16, result.getStepCounter()); | ||||
|         assertEquals(initStepCount + 16, result.getStepCounter()); | ||||
|         assertEquals(300.0, result.getReward(), 0.00001); | ||||
|         assertTrue(dqn.hasBeenReset); | ||||
|         assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); | ||||
|  | ||||
| @ -26,7 +26,7 @@ public class DoubleDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -52,7 +52,7 @@ public class DoubleDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -78,11 +78,11 @@ public class DoubleDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}), | ||||
|                 add(buildTransition(buildObservation(new double[]{3.3, 4.4}), | ||||
|                         1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}), | ||||
|                 add(buildTransition(buildObservation(new double[]{5.5, 6.6}), | ||||
|                         0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -108,4 +108,11 @@ public class DoubleDQNTest { | ||||
|     private Observation buildObservation(double[] data) { | ||||
|         return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); | ||||
|     } | ||||
| 
 | ||||
|     private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { | ||||
|         Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal); | ||||
|         result.setNextObservation(nextObservation); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -25,7 +25,7 @@ public class StandardDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, true, buildObservation(new double[]{11.0, 22.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -51,7 +51,7 @@ public class StandardDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -77,11 +77,11 @@ public class StandardDQNTest { | ||||
| 
 | ||||
|         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { | ||||
|             { | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{1.1, 2.2}), | ||||
|                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}), | ||||
|                         0, 1.0, false, buildObservation(new double[]{11.0, 22.0}))); | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{3.3, 4.4}), | ||||
|                 add(buildTransition(buildObservation(new double[]{3.3, 4.4}), | ||||
|                         1, 2.0, false, buildObservation(new double[]{33.0, 44.0}))); | ||||
|                 add(new Transition<Integer>(buildObservation(new double[]{5.5, 6.6}), | ||||
|                 add(buildTransition(buildObservation(new double[]{5.5, 6.6}), | ||||
|                         0, 3.0, true, buildObservation(new double[]{55.0, 66.0}))); | ||||
|             } | ||||
|         }; | ||||
| @ -108,4 +108,10 @@ public class StandardDQNTest { | ||||
|         return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); | ||||
|     } | ||||
| 
 | ||||
|     private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { | ||||
|         Transition<Integer> result = new Transition<Integer>(observation, action, reward, isTerminal); | ||||
|         result.setNextObservation(nextObservation); | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| } | ||||
|  | ||||
| @ -198,7 +198,7 @@ public class PolicyTest { | ||||
|         assertEquals(465.0, totalReward, 0.0001); | ||||
| 
 | ||||
|         // HistoryProcessor | ||||
|         assertEquals(27, hp.addCalls.size()); | ||||
|         assertEquals(16, hp.addCalls.size()); | ||||
|         assertEquals(31, hp.recordCalls.size()); | ||||
|         for(int i=0; i <= 30; ++i) { | ||||
|             assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001); | ||||
|  | ||||
| @ -142,6 +142,11 @@ public class DataManagerTrainingListenerTest { | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public int getCurrentEpochStep() { | ||||
|             return 0; | ||||
|         } | ||||
| 
 | ||||
|         @Getter | ||||
|         @Setter | ||||
|         private IHistoryProcessor historyProcessor; | ||||
|  | ||||
| @ -93,6 +93,9 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | ||||
|     private boolean done = false; | ||||
| 
 | ||||
|     public GymEnv(String envId, boolean render, boolean monitor) { | ||||
|         this(envId, render, monitor, (Integer)null); | ||||
|     } | ||||
|     public GymEnv(String envId, boolean render, boolean monitor, Integer seed) { | ||||
|         this.envId = envId; | ||||
|         this.render = render; | ||||
|         this.monitor = monitor; | ||||
| @ -107,6 +110,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | ||||
|                 Py_DecRef(PyRun_StringFlags("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')", Py_single_input, globals, locals, null)); | ||||
|                 checkPythonError(); | ||||
|             } | ||||
|             if (seed != null) { | ||||
|                 Py_DecRef(PyRun_StringFlags("env.seed(" + seed + ")", Py_single_input, globals, locals, null)); | ||||
|                 checkPythonError(); | ||||
|             } | ||||
|             PyObject shapeTuple = PyRun_StringFlags("env.observation_space.shape", Py_eval_input, globals, locals, null); | ||||
|             int[] shape = new int[(int)PyTuple_Size(shapeTuple)]; | ||||
|             for (int i = 0; i < shape.length; i++) { | ||||
| @ -125,7 +132,10 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> { | ||||
|     } | ||||
| 
 | ||||
|     public GymEnv(String envId, boolean render, boolean monitor, int[] actions) { | ||||
|         this(envId, render, monitor); | ||||
|         this(envId, render, monitor, null, actions); | ||||
|     } | ||||
|     public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) { | ||||
|         this(envId, render, monitor, seed); | ||||
|         actionTransformer = new ActionTransformer((HighLowDiscrete) getActionSpace(), actions); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
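The new GymEnv constructors make the Python-side seed optional while keeping the old signatures source-compatible: the seedless overloads delegate with a null seed, and a non-null seed is forwarded to env.seed(...) in the embedded interpreter. A hedged usage sketch ("CartPole-v0" and the Box/DiscreteSpace type parameters are illustrative assumptions, not taken from this diff):

    // Seeded: runs env.seed(123) on the Python side for reproducible rollouts.
    GymEnv<Box, Integer, DiscreteSpace> seeded =
            new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false, 123);

    // Unseeded: delegates to the seeded constructor with a null seed.
    GymEnv<Box, Integer, DiscreteSpace> unseeded =
            new GymEnv<Box, Integer, DiscreteSpace>("CartPole-v0", false, false);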