RL4J: Sanitize async learner (#327)
* refactoring global async to use a much simpler update procedure with a single global lock
* simplification of async learning algorithms, stabilization + better hyperparameters
* started to play with using mockito for tests
* working on refactoring tests for async classes and trying to make async simpler
* more work on mockito tests and making some tests much less complex and more explicit in what they are testing
* some fixes from merging
* do not allow copying of the current network to worker threads, fixing debug line
* adding some more tests around PR review
* adding more tests after review comments
* few more tests and fixes from PR review
* remove rename of maxEpochStep to maxStepsPerEpisode as we agreed to review this in a separate PR
* 2019 instead of 2018 on copyright header
* adding konduit copyright to files
* some more copyright headers

Signed-off-by: Bam4d <chrisbam4d@gmail.com>
Co-authored-by: Alexandre Boulanger <aboulang2002@yahoo.com>
parent 455a5d112d
commit 74420bca31
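The heart of the change is in AsyncGlobal (see its diff below): instead of queueing worker gradients and draining them on a dedicated global thread, workers now apply their gradients directly to the shared network under a single global lock. A minimal sketch of the pattern — `Network` and `Gradient` here are illustrative stand-ins, not RL4J's actual `NeuralNet` API:

```java
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Illustrative stand-ins for RL4J's Gradient and NeuralNet types.
interface Gradient { }
interface Network {
    Network copyOf();
    void copyFrom(Network source);
    void applyGradient(Gradient[] gradient, int batchSize);
}

// Sketch of the single-global-lock update: worker threads compute gradients
// locally, then apply them here one at a time; no queue, no updater thread.
final class GlobalNetwork {
    private final Network current;   // shared network, always up to date
    private final Network target;    // copy that worker threads read from
    private final Lock updateLock = new ReentrantLock();
    private int stepCount;

    GlobalNetwork(Network initial) {
        this.current = initial;
        this.target = initial.copyOf();
    }

    void applyGradient(Gradient[] gradient, int nSteps) {
        updateLock.lock();
        try {
            current.applyGradient(gradient, nSteps); // updates are serialized
            stepCount += nSteps;
            target.copyFrom(current);                // refresh the workers' copy
        } finally {
            updateLock.unlock();
        }
    }

    Network getTarget() {
        updateLock.lock();
        try {
            return target;
        } finally {
            updateLock.unlock();
        }
    }
}
```

The working assumption is that with the usual handful of worker threads, lock contention is negligible next to the cost of computing the gradients themselves, so the simpler design costs little.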
@@ -70,3 +70,6 @@ venv2/
 # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+
+# Ignore meld temp files
+*.orig
@@ -19,6 +19,7 @@ package org.deeplearning4j.rl4j.mdp;

 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;

 /**
@@ -31,20 +32,20 @@ import org.deeplearning4j.rl4j.space.ObservationSpace;
  * in a "functionnal manner" if step return a mdp
  *
  */
-public interface MDP<O, A, AS extends ActionSpace<A>> {
+public interface MDP<OBSERVATION, ACTION, ACTION_SPACE extends ActionSpace<ACTION>> {

-    ObservationSpace<O> getObservationSpace();
+    ObservationSpace<OBSERVATION> getObservationSpace();

-    AS getActionSpace();
+    ACTION_SPACE getActionSpace();

-    O reset();
+    OBSERVATION reset();

     void close();

-    StepReply<O> step(A action);
+    StepReply<OBSERVATION> step(ACTION action);

     boolean isDone();

-    MDP<O, A, AS> newInstance();
+    MDP<OBSERVATION, ACTION, ACTION_SPACE> newInstance();

 }
@@ -17,24 +17,24 @@
 package org.deeplearning4j.rl4j.space;

 /**
- * @param <A> the type of Action
+ * @param <ACTION> the type of Action
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16.
  * <p>
  * Should contain contextual information about the Action space, which is the space of all the actions that could be available.
  * Also must know how to return a randomly uniformly sampled action.
  */
-public interface ActionSpace<A> {
+public interface ActionSpace<ACTION> {

     /**
      * @return A random action,
      */
-    A randomAction();
+    ACTION randomAction();

-    Object encode(A action);
+    Object encode(ACTION action);

     int getSize();

-    A noOp();
+    ACTION noOp();


 }
@@ -121,6 +121,13 @@
             <version>${datavec.version}</version>
         </dependency>

+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>3.3.3</version>
+            <scope>test</scope>
+        </dependency>
+
     </dependencies>

     <profiles>
@@ -1,54 +1,54 @@
 /*******************************************************************************
  * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 package org.deeplearning4j.rl4j.experience;

 import org.deeplearning4j.rl4j.observation.Observation;

 import java.util.List;

 /**
  * A common interface to all classes capable of handling experience generated by the agents in a learning context.
  *
  * @param <A> Action type
  * @param <E> Experience type
  *
  * @author Alexandre Boulanger
  */
 public interface ExperienceHandler<A, E> {
     void addExperience(Observation observation, A action, double reward, boolean isTerminal);

     /**
      * Called when the episode is done with the last observation
      * @param observation
      */
     void setFinalObservation(Observation observation);

     /**
      * @return The size of the list that will be returned by generateTrainingBatch().
      */
     int getTrainingBatchSize();

     /**
      * The elements are returned in the historical order (i.e. in the order they happened)
      * @return The list of experience elements
      */
     List<E> generateTrainingBatch();

     /**
      * Signal the experience handler that a new episode is starting
      */
     void reset();
 }
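As a reading aid, the contract above can be satisfied by a plain in-memory buffer. The following is a hedged sketch, not the handler shipped in this PR (StateActionExperienceHandler, in the next hunk, is the real one); it assumes the Observation and StateActionPair types from this package and an (observation, action, reward, isTerminal) constructor:

```java
import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.experience.StateActionPair;
import org.deeplearning4j.rl4j.observation.Observation;

// Hypothetical buffering handler; the StateActionPair constructor signature is assumed.
public class BufferingExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {

    private List<StateActionPair<A>> buffer = new ArrayList<>();

    @Override
    public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
        buffer.add(new StateActionPair<>(observation, action, reward, isTerminal));
    }

    @Override
    public void setFinalObservation(Observation observation) {
        // A plain buffer has nothing to record here; n-step handlers bootstrap from it.
    }

    @Override
    public int getTrainingBatchSize() {
        return buffer.size();
    }

    @Override
    public List<StateActionPair<A>> generateTrainingBatch() {
        List<StateActionPair<A>> batch = buffer; // hand out experience in historical order
        buffer = new ArrayList<>();              // start collecting the next batch
        return batch;
    }

    @Override
    public void reset() {
        buffer = new ArrayList<>();
    }
}
```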
@@ -30,7 +30,7 @@ import java.util.List;
  */
 public class StateActionExperienceHandler<A> implements ExperienceHandler<A, StateActionPair<A>> {

-    private List<StateActionPair<A>> stateActionPairs;
+    private List<StateActionPair<A>> stateActionPairs = new ArrayList<>();

     public void setFinalObservation(Observation observation) {
         // Do nothing
@@ -1,21 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning;
-
-public interface EpochStepCounter {
-    int getCurrentEpochStep();
-}
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -28,9 +29,11 @@ import org.deeplearning4j.rl4j.mdp.MDP;
  * @author Alexandre Boulanger
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  */
-public interface IEpochTrainer extends EpochStepCounter {
-    int getStepCounter();
-    int getEpochCounter();
+public interface IEpochTrainer {
+    int getStepCount();
+    int getEpochCount();
+    int getEpisodeCount();
+    int getCurrentEpisodeStepCount();
     IHistoryProcessor getHistoryProcessor();
     MDP getMdp();
 }
@@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.learning;

 import lombok.AllArgsConstructor;
 import lombok.Builder;
+import lombok.Data;
 import lombok.Value;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -51,7 +52,7 @@ public interface IHistoryProcessor {

     @AllArgsConstructor
     @Builder
-    @Value
+    @Data
     public static class Configuration {
         @Builder.Default int historyLength = 4;
         @Builder.Default int rescaledWidth = 84;
@@ -21,19 +21,20 @@ import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
  *
  * A common interface that any training method should implement
  */
-public interface ILearning<O, A, AS extends ActionSpace<A>> {
+public interface ILearning<O extends Encodable, A, AS extends ActionSpace<A>> {

     IPolicy<O, A> getPolicy();

     void train();

-    int getStepCounter();
+    int getStepCount();

     ILearningConfiguration getConfiguration();

@@ -38,13 +38,13 @@ import org.nd4j.linalg.factory.Nd4j;
  *
  */
 @Slf4j
-public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
         implements ILearning<O, A, AS>, NeuralNetFetchable<NN> {

     @Getter @Setter
-    private int stepCounter = 0;
+    protected int stepCount = 0;
     @Getter @Setter
-    private int epochCounter = 0;
+    private int epochCount = 0;
     @Getter @Setter
     private IHistoryProcessor historyProcessor = null;

@@ -73,11 +73,11 @@ public abstract class Learning<O, A, AS extends ActionSpace<A>, NN extends Neura
     public abstract NN getNeuralNet();

     public void incrementStep() {
-        stepCounter++;
+        stepCount++;
     }

     public void incrementEpoch() {
-        epochCounter++;
+        epochCount++;
     }

     public void setHistoryProcessor(HistoryProcessor.Configuration conf) {
@@ -20,13 +20,11 @@ package org.deeplearning4j.rl4j.learning.async;
 import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.nd4j.linalg.primitives.Pair;

-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -52,69 +50,75 @@
  * structure
  */
 @Slf4j
-public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
+public class AsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {

-    @Getter
     final private NN current;
-    final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
-    final private IAsyncLearningConfiguration configuration;
-    private final IAsyncLearning learning;
-    @Getter
-    private AtomicInteger T = new AtomicInteger(0);
-    @Getter
-    private NN target;
-    @Getter
-    private boolean running = true;

-    public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
+    private NN target;
+
+    final private IAsyncLearningConfiguration configuration;
+
+    @Getter
+    private final Lock updateLock;
+
+    /**
+     * The number of times the gradient has been updated by worker threads
+     */
+    @Getter
+    private int workerUpdateCount;
+
+    @Getter
+    private int stepCount;
+
+    public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration) {
         this.current = initial;
         target = (NN) initial.clone();
         this.configuration = configuration;
-        this.learning = learning;
-        queue = new ConcurrentLinkedQueue<>();
+
+        // This is used to sync between
+        updateLock = new ReentrantLock();
     }

     public boolean isTrainingComplete() {
-        return T.get() >= configuration.getMaxStep();
+        return stepCount >= configuration.getMaxStep();
     }

-    public void enqueue(Gradient[] gradient, Integer nstep) {
-        if (running && !isTrainingComplete()) {
-            queue.add(new Pair<>(gradient, nstep));
+    public void applyGradient(Gradient[] gradient, int nstep) {
+        if (isTrainingComplete()) {
+            return;
         }
+
+        try {
+            updateLock.lock();
+
+            current.applyGradient(gradient, nstep);
+
+            stepCount += nstep;
+            workerUpdateCount++;
+
+            int targetUpdateFrequency = configuration.getLearnerUpdateFrequency();
+
+            // If we have a target update frequency, this means we only want to update the workers after a certain number of async updates
+            // This can lead to more stable training
+            if (targetUpdateFrequency != -1 && workerUpdateCount % targetUpdateFrequency == 0) {
+                log.info("Updating target network at updates={} steps={}", workerUpdateCount, stepCount);
+            } else {
+                target.copy(current);
+            }
+        } finally {
+            updateLock.unlock();
+        }
     }

     @Override
-    public void run() {
-        while (!isTrainingComplete() && running) {
-            if (!queue.isEmpty()) {
-                Pair<Gradient[], Integer> pair = queue.poll();
-                T.addAndGet(pair.getSecond());
-                Gradient[] gradient = pair.getFirst();
-                synchronized (this) {
-                    current.applyGradient(gradient, pair.getSecond());
-                }
-                if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
-                        / configuration.getLearnerUpdateFrequency()) {
-                    log.info("TARGET UPDATE at T = " + T.get());
-                    synchronized (this) {
-                        target.copy(current);
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
-     */
-    public void terminate() {
-        if (running) {
-            running = false;
-            queue.clear();
-            learning.terminate();
+    public NN getTarget() {
+        try {
+            updateLock.lock();
+            return target;
+        } finally {
+            updateLock.unlock();
         }
     }
 }
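Since the PR adds mockito-core (see the pom.xml hunk above), the lock-based AsyncGlobal is easy to exercise in isolation. A hypothetical test sketch — the stubbing assumes NeuralNet declares a stubbable clone() and copy(), and uses the configuration getters shown in the diff:

```java
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.*;

public class AsyncGlobalTest {

    @Test
    public void applyGradient_updatesStepAndWorkerCounts() {
        // Stubbing assumption: clone() returns the mock that becomes the target copy.
        NeuralNet current = mock(NeuralNet.class);
        NeuralNet target = mock(NeuralNet.class);
        when(current.clone()).thenReturn(target);

        IAsyncLearningConfiguration conf = mock(IAsyncLearningConfiguration.class);
        when(conf.getMaxStep()).thenReturn(100);
        when(conf.getLearnerUpdateFrequency()).thenReturn(-1); // always refresh the target copy

        AsyncGlobal<NeuralNet> asyncGlobal = new AsyncGlobal<>(current, conf);

        asyncGlobal.applyGradient(new Gradient[0], 5);

        assertEquals(5, asyncGlobal.getStepCount());
        assertEquals(1, asyncGlobal.getWorkerUpdateCount());
        verify(current).applyGradient(any(Gradient[].class), eq(5));
        verify(target).copy(current);
    }
}
```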
@@ -40,9 +40,9 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-        extends Learning<O, A, AS, NN>
+public abstract class AsyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
+        extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN>
         implements IAsyncLearning {

     private Thread monitorThread = null;

@@ -69,10 +69,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa

     protected abstract IAsyncGlobal<NN> getAsyncGlobal();

-    protected void startGlobalThread() {
-        getAsyncGlobal().start();
-    }
-
     protected boolean isTrainingComplete() {
         return getAsyncGlobal().isTrainingComplete();
     }
@@ -87,7 +83,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
     private int progressMonitorFrequency = 20000;

     private void launchThreads() {
-        startGlobalThread();
         for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
             Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
             t.start();
@@ -99,8 +94,8 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
      * @return The current step
      */
     @Override
-    public int getStepCounter() {
-        return getAsyncGlobal().getT().get();
+    public int getStepCount() {
+        return getAsyncGlobal().getStepCount();
     }

     /**
|
||||||
monitorTraining();
|
monitorTraining();
|
||||||
}
|
}
|
||||||
|
|
||||||
cleanupPostTraining();
|
|
||||||
listeners.notifyTrainingFinished();
|
listeners.notifyTrainingFinished();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void monitorTraining() {
|
protected void monitorTraining() {
|
||||||
try {
|
try {
|
||||||
monitorThread = Thread.currentThread();
|
monitorThread = Thread.currentThread();
|
||||||
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
while (canContinue && !isTrainingComplete()) {
|
||||||
canContinue = listeners.notifyTrainingProgress(this);
|
canContinue = listeners.notifyTrainingProgress(this);
|
||||||
if (!canContinue) {
|
if (!canContinue) {
|
||||||
return;
|
return;
|
||||||
|
@@ -152,11 +146,6 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
         monitorThread = null;
     }

-    protected void cleanupPostTraining() {
-        // Worker threads stops automatically when the global thread stops
-        getAsyncGlobal().terminate();
-    }
-
     /**
      * Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
      */
@@ -32,6 +32,7 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.factory.Nd4j;
@@ -47,39 +48,63 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
+public abstract class AsyncThread<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
         extends Thread implements IEpochTrainer {

     @Getter
     private int threadNumber;

     @Getter
     protected final int deviceNum;

+    /**
+     * The number of steps that this async thread has produced
+     */
     @Getter @Setter
-    private int stepCounter = 0;
+    protected int stepCount = 0;

+    /**
+     * The number of epochs (updates) that this thread has sent to the global learner
+     */
     @Getter @Setter
-    private int epochCounter = 0;
+    protected int epochCount = 0;
+
+    /**
+     * The number of environment episodes that have been played out
+     */
+    @Getter @Setter
+    protected int episodeCount = 0;
+
+    /**
+     * The number of steps in the current episode
+     */
+    @Getter
+    protected int currentEpisodeStepCount = 0;
+
+    /**
+     * If the current episode needs to be reset
+     */
+    boolean episodeComplete = true;

     @Getter @Setter
     private IHistoryProcessor historyProcessor;

-    @Getter
-    private int currentEpochStep = 0;
-
-    private boolean isEpochStarted = false;
-    private final LegacyMDPWrapper<O, A, AS> mdp;
+    private boolean isEpisodeStarted = false;
+    private final LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp;

     private final TrainingListenerList listeners;

-    public AsyncThread(IAsyncGlobal<NN> asyncGlobal, MDP<O, A, AS> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
-        this.mdp = new LegacyMDPWrapper<O, A, AS>(mdp, null, this);
+    public AsyncThread(MDP<OBSERVATION, ACTION, ACTION_SPACE> mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) {
+        this.mdp = new LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE>(mdp, null);
         this.listeners = listeners;
         this.threadNumber = threadNumber;
         this.deviceNum = deviceNum;
     }

-    public MDP<O, A, AS> getMdp() {
+    public MDP<OBSERVATION, ACTION, ACTION_SPACE> getMdp() {
         return mdp.getWrappedMDP();
     }
-    protected LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper() {
+    protected LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> getLegacyMDPWrapper() {
         return mdp;
     }
@@ -92,13 +117,13 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
         mdp.setHistoryProcessor(historyProcessor);
     }

-    protected void postEpoch() {
+    protected void postEpisode() {
         if (getHistoryProcessor() != null)
             getHistoryProcessor().stopMonitor();

     }

-    protected void preEpoch() {
+    protected void preEpisode() {
         // Do nothing
     }

@@ -125,74 +150,69 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
      */
     @Override
     public void run() {
-        try {
-            RunContext context = new RunContext();
-            Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);
+        RunContext context = new RunContext();
+        Nd4j.getAffinityManager().unsafeSetDevice(deviceNum);

-            log.info("ThreadNum-" + threadNumber + " Started!");
+        log.info("ThreadNum-" + threadNumber + " Started!");

-            while (!getAsyncGlobal().isTrainingComplete() && getAsyncGlobal().isRunning()) {
-                if (!isEpochStarted) {
-                    boolean canContinue = startNewEpoch(context);
-                    if (!canContinue) {
-                        break;
-                    }
-                }
+        while (!getAsyncGlobal().isTrainingComplete()) {

-                handleTraining(context);
+            if (episodeComplete) {
+                startEpisode(context);
+            }

-                if (currentEpochStep >= getConf().getMaxEpochStep() || getMdp().isDone()) {
-                    boolean canContinue = finishEpoch(context);
-                    if (!canContinue) {
-                        break;
-                    }
-                    ++epochCounter;
-                }
+            if(!startEpoch(context)) {
+                break;
             }
-        }
-        finally {
-            terminateWork();
+
+            episodeComplete = handleTraining(context);
+
+            if(!finishEpoch(context)) {
+                break;
+            }
+
+            if(episodeComplete) {
+                finishEpisode(context);
+            }
         }
     }

-    private void handleTraining(RunContext context) {
-        int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
-        SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
+    private boolean finishEpoch(RunContext context) {
+        epochCount++;
+        IDataManager.StatEntry statEntry = new AsyncStatEntry(stepCount, epochCount, context.rewards, currentEpisodeStepCount, context.score);
+        return listeners.notifyEpochTrainingResult(this, statEntry);
+    }
+
+    private boolean startEpoch(RunContext context) {
+        return listeners.notifyNewEpoch(this);
+    }
+
+    private boolean handleTraining(RunContext context) {
+        int maxTrainSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpisodeStepCount);
+        SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxTrainSteps);
+
         context.obs = subEpochReturn.getLastObs();
         context.rewards += subEpochReturn.getReward();
         context.score = subEpochReturn.getScore();
+
+        return subEpochReturn.isEpisodeComplete();
     }

-    private boolean startNewEpoch(RunContext context) {
+    private void startEpisode(RunContext context) {
         getCurrent().reset();
         Learning.InitMdp<Observation> initMdp = refacInitMdp();

         context.obs = initMdp.getLastObs();
         context.rewards = initMdp.getReward();

-        isEpochStarted = true;
-        preEpoch();
-
-        return listeners.notifyNewEpoch(this);
+        preEpisode();
+        episodeCount++;
     }

-    private boolean finishEpoch(RunContext context) {
-        isEpochStarted = false;
-        postEpoch();
-        IDataManager.StatEntry statEntry = new AsyncStatEntry(getStepCounter(), epochCounter, context.rewards, currentEpochStep, context.score);
-
-        log.info("ThreadNum-" + threadNumber + " Epoch: " + getCurrentEpochStep() + ", reward: " + context.rewards);
-
-        return listeners.notifyEpochTrainingResult(this, statEntry);
-    }
-
-    private void terminateWork() {
-        getAsyncGlobal().terminate();
-        if(isEpochStarted) {
-            postEpoch();
-        }
+    private void finishEpisode(RunContext context) {
+        postEpisode();
+
+        log.info("ThreadNum-{} Episode step: {}, Episode: {}, Epoch: {}, reward: {}", threadNumber, currentEpisodeStepCount, episodeCount, epochCount, context.rewards);
     }

     protected abstract NN getCurrent();
@@ -201,35 +221,35 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne

     protected abstract IAsyncLearningConfiguration getConf();

-    protected abstract IPolicy<O, A> getPolicy(NN net);
+    protected abstract IPolicy<OBSERVATION, ACTION> getPolicy(NN net);

     protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep);

     private Learning.InitMdp<Observation> refacInitMdp() {
-        currentEpochStep = 0;
+        currentEpisodeStepCount = 0;

         double reward = 0;

-        LegacyMDPWrapper<O, A, AS> mdp = getLegacyMDPWrapper();
+        LegacyMDPWrapper<OBSERVATION, ACTION, ACTION_SPACE> mdp = getLegacyMDPWrapper();
         Observation observation = mdp.reset();

-        A action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
+        ACTION action = mdp.getActionSpace().noOp(); //by convention should be the NO_OP
         while (observation.isSkipped() && !mdp.isDone()) {
             StepReply<Observation> stepReply = mdp.step(action);

             reward += stepReply.getReward();
             observation = stepReply.getObservation();

-            incrementStep();
+            incrementSteps();
         }

         return new Learning.InitMdp(0, observation, reward);

     }

-    public void incrementStep() {
-        ++stepCounter;
-        ++currentEpochStep;
+    public void incrementSteps() {
+        stepCount++;
+        currentEpisodeStepCount++;
     }

     @AllArgsConstructor
@@ -239,6 +259,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
         Observation lastObs;
         double reward;
         double score;
+        boolean episodeComplete;
     }

     @AllArgsConstructor
@@ -24,6 +24,9 @@ import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.experience.ExperienceHandler;
 import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
+import org.deeplearning4j.rl4j.experience.ExperienceHandler;
+import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler;
+import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
@@ -31,15 +34,19 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.Stack;

 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
- *
+ * <p>
  * Async Learning specialized for the Discrete Domain
- *
  */
-public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
+public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
         extends AsyncThread<O, Integer, DiscreteSpace, NN> {

     @Getter
     private NN current;
@@ -48,7 +55,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
     private UpdateAlgorithm<NN> updateAlgorithm;

     // TODO: Make it configurable with a builder
-    @Setter(AccessLevel.PROTECTED)
+    @Setter(AccessLevel.PROTECTED) @Getter
     private ExperienceHandler experienceHandler = new StateActionExperienceHandler();

     public AsyncThreadDiscrete(IAsyncGlobal<NN> asyncGlobal,
@@ -56,9 +63,9 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
                                TrainingListenerList listeners,
                                int threadNumber,
                                int deviceNum) {
-        super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
+        super(mdp, listeners, threadNumber, deviceNum);
         synchronized (asyncGlobal) {
-            current = (NN)asyncGlobal.getCurrent().clone();
+            current = (NN) asyncGlobal.getTarget().clone();
         }
     }

@@ -72,7 +79,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
     }

     @Override
-    protected void preEpoch() {
+    protected void preEpisode() {
         experienceHandler.reset();
     }

@@ -81,28 +88,23 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
      * "Subepoch" correspond to the t_max-step iterations
      * that stack rewards with t_max MiniTrans
      *
      * @param sObs the obs to start from
-     * @param nstep the number of max nstep (step until t_max or state is terminal)
+     * @param trainingSteps the number of training steps
      * @return subepoch training informations
      */
-    public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
-
-        synchronized (getAsyncGlobal()) {
-            current.copy(getAsyncGlobal().getCurrent());
-        }
+    public SubEpochReturn trainSubEpoch(Observation sObs, int trainingSteps) {
+        current.copy(getAsyncGlobal().getTarget());

         Observation obs = sObs;
         IPolicy<O, Integer> policy = getPolicy(current);

         Integer action = getMdp().getActionSpace().noOp();
-        IHistoryProcessor hp = getHistoryProcessor();
-        int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;

         double reward = 0;
         double accuReward = 0;
-        int stepAtStart = getCurrentEpochStep();
-        int lastStep = nstep * skipFrame + stepAtStart;
-        while (!getMdp().isDone() && getCurrentEpochStep() < lastStep) {
+
+        while (!getMdp().isDone() && experienceHandler.getTrainingBatchSize() != trainingSteps) {

             //if step of training, just repeat lastAction
             if (!obs.isSkipped()) {
@@ -115,20 +117,26 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
             if (!obs.isSkipped()) {
                 experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone());
                 accuReward = 0;
+
+                incrementSteps();
             }

             obs = stepReply.getObservation();
             reward += stepReply.getReward();
-
-            incrementStep();
         }

-        if (getMdp().isDone() && getCurrentEpochStep() < lastStep) {
+        boolean episodeComplete = getMdp().isDone() || getConf().getMaxEpochStep() == currentEpisodeStepCount;
+
+        if (episodeComplete && experienceHandler.getTrainingBatchSize() != trainingSteps) {
             experienceHandler.setFinalObservation(obs);
         }

-        getAsyncGlobal().enqueue(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), getCurrentEpochStep());
-
-        return new SubEpochReturn(getCurrentEpochStep() - stepAtStart, obs, reward, current.getLatestScore());
+        int experienceSize = experienceHandler.getTrainingBatchSize();
+
+        getAsyncGlobal().applyGradient(updateAlgorithm.computeGradients(current, experienceHandler.generateTrainingBatch()), experienceSize);
+
+        return new SubEpochReturn(experienceSize, obs, reward, current.getLatestScore(), episodeComplete);
     }

 }
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -22,17 +23,29 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import java.util.concurrent.atomic.AtomicInteger;

 public interface IAsyncGlobal<NN extends NeuralNet> {
-    boolean isRunning();
+
     boolean isTrainingComplete();
-    void start();

     /**
-     * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded.
+     * The number of updates that have been applied by worker threads.
      */
-    void terminate();
+    int getWorkerUpdateCount();

-    AtomicInteger getT();
-    NN getCurrent();
+    /**
+     * The total number of environment steps that have been processed.
+     */
+    int getStepCount();
+
+    /**
+     * A copy of the global network that is updated after a certain number of worker episodes.
+     */
     NN getTarget();
-    void enqueue(Gradient[] gradient, Integer nstep);
+
+    /**
+     * Apply gradients to the global network
+     * @param gradient
+     * @param batchSize
+     */
+    void applyGradient(Gradient[] gradient, int batchSize);
+
 }
@@ -1,26 +1,26 @@
 /*******************************************************************************
  * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 package org.deeplearning4j.rl4j.learning.async;

 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.network.NeuralNet;

 import java.util.List;

 public interface UpdateAlgorithm<NN extends NeuralNet> {
     Gradient[] computeGradients(NN current, List<StateActionPair<Integer>> experience);
 }
|
@ -57,7 +57,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
|
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf);
|
||||||
|
|
||||||
Long seed = conf.getSeed();
|
Long seed = conf.getSeed();
|
||||||
Random rnd = Nd4j.getRandom();
|
Random rnd = Nd4j.getRandom();
|
||||||
|
|
|
@@ -27,7 +27,6 @@ import org.deeplearning4j.rl4j.policy.ACPolicy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.api.rng.Random;
@@ -73,6 +72,6 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
     @Override
     protected UpdateAlgorithm<IActorCritic> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new A3CUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getLearnerUpdateFrequency(), conf.getGamma());
+        return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma());
     }
 }
@@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,28 +26,25 @@ import org.nd4j.linalg.indexing.NDArrayIndex;

 import java.util.List;

-public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
+/**
+ * The Advantage Actor-Critic update algorithm can be used by A2C and A3C algorithms alike
+ */
+public class AdvantageActorCriticUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {

-    private final IAsyncGlobal asyncGlobal;
     private final int[] shape;
     private final int actionSpaceSize;
-    private final int targetDqnUpdateFreq;
     private final double gamma;
     private final boolean recurrent;

-    public A3CUpdateAlgorithm(IAsyncGlobal asyncGlobal,
-                              int[] shape,
-                              int actionSpaceSize,
-                              int targetDqnUpdateFreq,
-                              double gamma) {
-
-        this.asyncGlobal = asyncGlobal;
+    public AdvantageActorCriticUpdateAlgorithm(boolean recurrent,
+                                               int[] shape,
+                                               int actionSpaceSize,
+                                               double gamma) {

         //if recurrent then train as a time serie with a batch size of 1
-        recurrent = asyncGlobal.getCurrent().isRecurrent();
+        this.recurrent = recurrent;
         this.shape = shape;
         this.actionSpaceSize = actionSpaceSize;
-        this.targetDqnUpdateFreq = targetDqnUpdateFreq;
         this.gamma = gamma;
     }

@@ -65,18 +61,12 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
                 : Nd4j.zeros(size, actionSpaceSize);

         StateActionPair<Integer> stateActionPair = experience.get(size - 1);
-        double r;
-        if(stateActionPair.isTerminal()) {
-            r = 0;
-        }
-        else {
-            INDArray[] output = null;
-            if (targetDqnUpdateFreq == -1)
-                output = current.outputAll(stateActionPair.getObservation().getData());
-            else synchronized (asyncGlobal) {
-                output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
-            }
-            r = output[0].getDouble(0);
+        double value;
+        if (stateActionPair.isTerminal()) {
+            value = 0;
+        } else {
+            INDArray[] output = current.outputAll(stateActionPair.getObservation().getData());
+            value = output[0].getDouble(0);
         }

         for (int i = size - 1; i >= 0; --i) {
@@ -86,7 +76,7 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {

             INDArray[] output = current.outputAll(observationData);

-            r = stateActionPair.getReward() + gamma * r;
+            value = stateActionPair.getReward() + gamma * value;
             if (recurrent) {
                 input.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(i)).assign(observationData);
             } else {
@@ -94,11 +84,11 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
             }

             //the critic
-            targets.putScalar(i, r);
+            targets.putScalar(i, value);

             //the actor
             double expectedV = output[0].getDouble(0);
-            double advantage = r - expectedV;
+            double advantage = value - expectedV;
             if (recurrent) {
                 logSoftmax.putScalar(0, stateActionPair.getAction(), i, advantage);
             } else {
@ -108,6 +98,6 @@ public class A3CUpdateAlgorithm implements UpdateAlgorithm<IActorCritic> {
|
||||||
|
|
||||||
// targets -> value, critic
|
// targets -> value, critic
|
||||||
// logSoftmax -> policy, actor
|
// logSoftmax -> policy, actor
|
||||||
return current.gradient(input, new INDArray[] {targets, logSoftmax});
|
return current.gradient(input, new INDArray[]{targets, logSoftmax});
|
||||||
}
|
}
|
||||||
}
|
}
|
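For orientation, the refactored actor-critic update above reduces to the following self-contained computation. This is a sketch with illustrative names, not rl4j API; only the arithmetic (n-step bootstrapped returns for the critic, return-minus-value advantages for the actor) mirrors the hunk.

    public class AdvantageTargetsSketch {
        // rewards[i]:     reward observed at step i of the collected segment
        // stateValues[i]: critic estimate V(s_i), taken from output[0] in the hunk
        // bootstrap:      V(s_n) from the worker's network copy, or 0 on a terminal step
        public static double[][] buildTargets(double[] rewards, double[] stateValues,
                                              double bootstrap, double gamma) {
            int n = rewards.length;
            double[] criticTargets = new double[n]; // "targets" in the hunk
            double[] advantages = new double[n];    // scattered into "logSoftmax"
            double value = bootstrap;
            for (int i = n - 1; i >= 0; --i) {
                value = rewards[i] + gamma * value;     // n-step discounted return
                criticTargets[i] = value;               // critic regresses to the return
                advantages[i] = value - stateValues[i]; // actor signal: return minus baseline
            }
            return new double[][]{criticTargets, advantages};
        }
    }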
@@ -50,7 +50,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
     public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
         this.mdp = mdp;
         this.configuration = conf;
-        this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
+        this.asyncGlobal = new AsyncGlobal<>(dqn, conf);
     }
 
     @Override
@@ -59,7 +59,7 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
     }
 
     public IDQN getNeuralNet() {
-        return asyncGlobal.getCurrent();
+        return asyncGlobal.getTarget();
     }
 
     public IPolicy<O, Integer> getPolicy() {
@@ -30,8 +30,8 @@ import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.api.rng.Random;
+import org.nd4j.linalg.factory.Nd4j;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@@ -57,7 +57,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
         rnd = Nd4j.getRandom();
 
         Long seed = conf.getSeed();
-        if(seed != null) {
+        if (seed != null) {
             rnd.setSeed(seed + threadNumber);
         }
 
@@ -72,6 +72,6 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
     @Override
     protected UpdateAlgorithm<IDQN> buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new QLearningUpdateAlgorithm(asyncGlobal, shape, getMdp().getActionSpace().getSize(), conf.getTargetDqnUpdateFreq(), conf.getGamma());
+        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
     }
 }
@@ -18,7 +18,6 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -28,22 +27,16 @@ import java.util.List;
 
 public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
 
-    private final IAsyncGlobal asyncGlobal;
     private final int[] shape;
     private final int actionSpaceSize;
-    private final int targetDqnUpdateFreq;
     private final double gamma;
 
-    public QLearningUpdateAlgorithm(IAsyncGlobal asyncGlobal,
-                                    int[] shape,
+    public QLearningUpdateAlgorithm(int[] shape,
                                     int actionSpaceSize,
-                                    int targetDqnUpdateFreq,
                                     double gamma) {
 
-        this.asyncGlobal = asyncGlobal;
         this.shape = shape;
         this.actionSpaceSize = actionSpaceSize;
-        this.targetDqnUpdateFreq = targetDqnUpdateFreq;
         this.gamma = gamma;
     }
 
@@ -58,16 +51,11 @@ public class QLearningUpdateAlgorithm implements UpdateAlgorithm<IDQN> {
         StateActionPair<Integer> stateActionPair = experience.get(size - 1);
 
         double r;
-        if(stateActionPair.isTerminal()) {
+        if (stateActionPair.isTerminal()) {
             r = 0;
-        }
-        else {
+        } else {
             INDArray[] output = null;
-            if (targetDqnUpdateFreq == -1)
-                output = current.outputAll(stateActionPair.getObservation().getData());
-            else synchronized (asyncGlobal) {
-                output = asyncGlobal.getTarget().outputAll(stateActionPair.getObservation().getData());
-            }
+            output = current.outputAll(stateActionPair.getObservation().getData());
             r = Nd4j.max(output[0]).getDouble(0);
         }
 
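The simplified n-step Q-learning target is the same backward recursion as in the actor-critic case, bootstrapped from the max-Q of the worker's own network copy instead of a separately synchronized target network. A minimal sketch with illustrative names, not rl4j API:

    public class NStepQTargetsSketch {
        // maxQNext: max_a Q(s_n, a) from the worker's own copy; ignored if the segment was terminal
        public static double[] targets(double[] rewards, double maxQNext, boolean terminal, double gamma) {
            double r = terminal ? 0 : maxQNext;   // no targetDqnUpdateFreq branch any more
            double[] labels = new double[rewards.length];
            for (int i = rewards.length - 1; i >= 0; --i) {
                r = rewards[i] + gamma * r;
                labels[i] = r;                    // training label for the action taken at step i
            }
            return labels;
        }
    }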
@@ -20,8 +20,14 @@ public interface IAsyncLearningConfiguration extends ILearningConfiguration {
 
     int getNumThreads();
 
+    /**
+     * The number of steps to collect for each worker thread between each global update
+     */
     int getNStep();
 
+    /**
+     * The frequency of worker thread gradient updates to perform a copy of the current working network to the target network
+     */
     int getLearnerUpdateFrequency();
 
     int getMaxStep();
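To see how these knobs interact, here is a tiny back-of-the-envelope sketch. The values and the class are hypothetical, chosen only to show the units each getter is expressed in; they are not rl4j defaults.

    public class AsyncConfigSketch {
        public static void main(String[] args) {
            int numThreads = 8;               // concurrent workers, each on its own MDP copy
            int nStep = 5;                    // steps a worker collects before pushing one gradient
            int learnerUpdateFrequency = 100; // gradient updates between current -> target copies
            int maxStep = 100_000;            // global step budget across all workers

            int gradientUpdates = maxStep / nStep; // each collected segment yields one global update
            System.out.println("approx. global gradient updates: " + gradientUpdates);
            System.out.println("approx. target-network refreshes: " + gradientUpdates / learnerUpdateFrequency);
        }
    }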
@@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.learning.listener.*;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager;
 
 /**
@@ -35,8 +36,8 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * @author Alexandre Boulanger
  */
 @Slf4j
-public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-        extends Learning<O, A, AS, NN> implements IEpochTrainer {
+public abstract class SyncLearning<OBSERVATION extends Encodable, ACTION, ACTION_SPACE extends ActionSpace<ACTION>, NN extends NeuralNet>
+        extends Learning<OBSERVATION, ACTION, ACTION_SPACE, NN> implements IEpochTrainer {
 
     private final TrainingListenerList listeners = new TrainingListenerList();
 
@@ -85,7 +86,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
 
         boolean canContinue = listeners.notifyTrainingStarted();
         if (canContinue) {
-            while (getStepCounter() < getConfiguration().getMaxStep()) {
+            while (this.getStepCount() < getConfiguration().getMaxStep()) {
                 preEpoch();
                 canContinue = listeners.notifyNewEpoch(this);
                 if (!canContinue) {
@@ -100,14 +101,14 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
 
                 postEpoch();
 
-                if(getEpochCounter() % progressMonitorFrequency == 0) {
+                if(getEpochCount() % progressMonitorFrequency == 0) {
                     canContinue = listeners.notifyTrainingProgress(this);
                     if (!canContinue) {
                         break;
                     }
                 }
 
-                log.info("Epoch: " + getEpochCounter() + ", reward: " + statEntry.getReward());
+                log.info("Epoch: " + getEpochCount() + ", reward: " + statEntry.getReward());
                 incrementEpoch();
             }
         }
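Stripped of rl4j specifics, the control flow of train() above has this shape. The sketch below is self-contained and runnable; everything except the listener callback names is a stand-in (the fixed step increment simulates one episode of trainEpoch()).

    public class SyncLoopSketch {
        interface Listeners {
            boolean notifyNewEpoch();
            boolean notifyTrainingProgress();
        }

        static void train(Listeners listeners, int maxStep, int progressMonitorFrequency) {
            int stepCount = 0;
            int epochCount = 0;
            while (stepCount < maxStep) {
                // preEpoch(): reset the environment and per-episode counters
                if (!listeners.notifyNewEpoch()) {
                    break;
                }
                stepCount += 200; // trainEpoch(): plays one episode, advancing the global step count
                // postEpoch(): cleanup and stat collection
                if (epochCount % progressMonitorFrequency == 0 && !listeners.notifyTrainingProgress()) {
                    break;
                }
                epochCount++;
            }
        }

        public static void main(String[] args) {
            train(new Listeners() {
                public boolean notifyNewEpoch() { return true; }
                public boolean notifyTrainingProgress() { return true; }
            }, 1000, 5);
        }
    }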
@@ -19,21 +19,16 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
 
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
-import lombok.AccessLevel;
 import lombok.AllArgsConstructor;
 import lombok.Builder;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import lombok.Setter;
 import lombok.Value;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
-import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
+import org.deeplearning4j.rl4j.learning.IEpochTrainer;
 import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
-import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
 import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -43,8 +38,6 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
-import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -58,7 +51,7 @@ import java.util.List;
 @Slf4j
 public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
         extends SyncLearning<O, A, AS, IDQN>
-        implements TargetQNetworkSource, EpochStepCounter {
+        implements TargetQNetworkSource, IEpochTrainer {
 
     protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
 
@@ -90,7 +83,10 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     protected abstract QLStepReturn<Observation> trainStep(Observation obs);
 
     @Getter
-    private int currentEpochStep = 0;
+    private int episodeCount;
 
+    @Getter
+    private int currentEpisodeStepCount = 0;
 
     protected StatEntry trainEpoch() {
         resetNetworks();
@@ -104,9 +100,9 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
         double meanQ = 0;
         int numQ = 0;
         List<Double> scores = new ArrayList<>();
-        while (currentEpochStep < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
+        while (currentEpisodeStepCount < getConfiguration().getMaxEpochStep() && !getMdp().isDone()) {
 
-            if (getStepCounter() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
+            if (this.getStepCount() % getConfiguration().getTargetDqnUpdateFreq() == 0) {
                 updateTargetNetwork();
             }
 
@@ -132,20 +128,20 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
         meanQ /= (numQ + 0.001); //avoid div zero
 
 
-        StatEntry statEntry = new QLStatEntry(getStepCounter(), getEpochCounter(), reward, currentEpochStep, scores,
+        StatEntry statEntry = new QLStatEntry(this.getStepCount(), getEpochCount(), reward, currentEpisodeStepCount, scores,
                         getEgPolicy().getEpsilon(), startQ, meanQ);
 
         return statEntry;
     }
 
     protected void finishEpoch(Observation observation) {
-        // Do Nothing
+        episodeCount++;
     }
 
     @Override
     public void incrementStep() {
         super.incrementStep();
-        ++currentEpochStep;
+        ++currentEpisodeStepCount;
     }
 
     protected void resetNetworks() {
@@ -154,7 +150,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
     }
 
     private InitMdp<Observation> refacInitMdp() {
-        currentEpochStep = 0;
+        currentEpisodeStepCount = 0;
 
         double reward = 0;
 
@@ -47,6 +47,7 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.List;
 
 
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
  * <p>
@@ -90,7 +91,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
     public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
                     int epsilonNbStep, Random random) {
         this.configuration = conf;
-        this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
+        this.mdp = new LegacyMDPWrapper<>(mdp, null);
         qNetwork = dqn;
         targetQNetwork = dqn.clone();
         policy = new DQNPolicy(getQNetwork());
@@ -164,13 +165,13 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
 
         // Update NN
         // FIXME: maybe start updating when experience replay has reached a certain size instead of using "updateStart"?
-        if (getStepCounter() > updateStart) {
+        if (this.getStepCount() > updateStart) {
             DataSet targets = setTarget(experienceHandler.generateTrainingBatch());
             getQNetwork().fit(targets.getFeatures(), targets.getLabels());
         }
         }
 
-        return new QLStepReturn<Observation>(maxQ, getQNetwork().getLatestScore(), stepReply);
+        return new QLStepReturn<>(maxQ, getQNetwork().getLatestScore(), stepReply);
     }
 
     protected DataSet setTarget(List<Transition<Integer>> transitions) {
@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.observation;
 
 import lombok.Getter;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
@@ -40,7 +40,7 @@ public class EncodableToImageWritableTransform implements Operation<Encodable, I
 
     @Override
     public ImageWritable transform(Encodable encodable) {
-        INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels);
+        INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
         Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
         return new ImageWritable(converter.convert(mat));
     }
@@ -25,6 +25,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 
@@ -40,7 +41,7 @@ import org.nd4j.linalg.api.rng.Random;
  */
 @AllArgsConstructor
 @Slf4j
-public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
+public class EpsGreedy<O extends Encodable, A, AS extends ActionSpace<A>> extends Policy<O, A> {
 
     final private Policy<O, A> policy;
     final private MDP<O, A, AS> mdp;
@@ -57,8 +58,8 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
     public A nextAction(INDArray input) {
 
         double ep = getEpsilon();
-        if (learning.getStepCounter() % 500 == 1)
-            log.info("EP: " + ep + " " + learning.getStepCounter());
+        if (learning.getStepCount() % 500 == 1)
+            log.info("EP: " + ep + " " + learning.getStepCount());
         if (rnd.nextDouble() > ep)
             return policy.nextAction(input);
         else
@@ -70,6 +71,6 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
     }
 
     public double getEpsilon() {
-        return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
+        return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCount() - updateStart) * 1.0 / epsilonNbStep));
     }
 }
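Written out, the annealing schedule returned by getEpsilon() (unchanged by this rename) is

    \epsilon(t) = \min\!\left(1,\; \max\!\left(\epsilon_{\min},\; 1 - \frac{t - t_0}{N}\right)\right)

with t the global step count, t_0 = updateStart and N = epsilonNbStep: exploration decays linearly from 1 down to minEpsilon over N steps once the warm-up period has passed.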
@@ -7,7 +7,7 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-public interface IPolicy<O, A> {
+public interface IPolicy<O extends Encodable, A> {
     <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp);
     A nextAction(INDArray input);
     A nextAction(Observation observation);
@@ -16,10 +16,7 @@
 
 package org.deeplearning4j.rl4j.policy;
 
-import lombok.Getter;
-import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
 import org.deeplearning4j.rl4j.learning.HistoryProcessor;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.Learning;
@@ -27,6 +24,7 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 
 /**
@@ -36,7 +34,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
  *
  * A Policy responsability is to choose the next action given a state
  */
-public abstract class Policy<O, A> implements IPolicy<O, A> {
+public abstract class Policy<O extends Encodable, A> implements IPolicy<O, A> {
 
     public abstract NeuralNet getNeuralNet();
 
@@ -54,10 +52,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
     public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
         resetNetworks();
 
-        RefacEpochStepCounter epochStepCounter = new RefacEpochStepCounter();
-        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp, epochStepCounter);
+        LegacyMDPWrapper<O, A, AS> mdpWrapper = new LegacyMDPWrapper<O, A, AS>(mdp, hp);
 
-        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp, epochStepCounter);
+        Learning.InitMdp<Observation> initMdp = refacInitMdp(mdpWrapper, hp);
         Observation obs = initMdp.getLastObs();
 
         double reward = initMdp.getReward();
@@ -79,7 +76,6 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
             reward += stepReply.getReward();
 
             obs = stepReply.getObservation();
-            epochStepCounter.incrementEpochStep();
         }
 
         return reward;
@@ -89,8 +85,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
         getNeuralNet().reset();
     }
 
-    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
-        epochStepCounter.setCurrentEpochStep(0);
+    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp) {
 
         double reward = 0;
 
@@ -104,21 +99,9 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
             reward += stepReply.getReward();
             observation = stepReply.getObservation();
 
-            epochStepCounter.incrementEpochStep();
         }
 
         return new Learning.InitMdp(0, observation, reward);
     }
 
-    public class RefacEpochStepCounter implements EpochStepCounter {
-
-        @Getter
-        @Setter
-        private int currentEpochStep = 0;
-
-        public void incrementEpochStep() {
-            ++currentEpochStep;
-        }
-    }
 }
@@ -278,7 +278,7 @@ public class DataManager implements IDataManager {
         Path infoPath = Paths.get(getInfo());
 
         Info info = new Info(iLearning.getClass().getSimpleName(), iLearning.getMdp().getClass().getSimpleName(),
-                        iLearning.getConfiguration(), iLearning.getStepCounter(), System.currentTimeMillis());
+                        iLearning.getConfiguration(), iLearning.getStepCount(), System.currentTimeMillis());
         String toWrite = toJson(info);
 
         Files.write(infoPath, toWrite.getBytes(), StandardOpenOption.TRUNCATE_EXISTING);
@@ -300,12 +300,12 @@ public class DataManager implements IDataManager {
         if (!saveData)
             return;
 
-        save(getModelDir() + "/" + learning.getStepCounter() + ".training", learning);
+        save(getModelDir() + "/" + learning.getStepCount() + ".training", learning);
         if(learning instanceof NeuralNetFetchable) {
             try {
-                ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCounter() + ".model");
+                ((NeuralNetFetchable)learning).getNeuralNet().save(getModelDir() + "/" + learning.getStepCount() + ".model");
             } catch (UnsupportedOperationException e) {
-                String path = getModelDir() + "/" + learning.getStepCounter();
+                String path = getModelDir() + "/" + learning.getStepCount();
                 ((IActorCritic)((NeuralNetFetchable)learning).getNeuralNet()).save(path + "_value.model", path + "_policy.model");
             }
         }
@@ -40,7 +40,7 @@ public class DataManagerTrainingListener implements TrainingListener {
             if (trainer instanceof AsyncThread) {
                 filename += ((AsyncThread) trainer).getThreadNumber() + "-";
             }
-            filename += trainer.getEpochCounter() + "-" + trainer.getStepCounter() + ".mp4";
+            filename += trainer.getEpochCount() + "-" + trainer.getStepCount() + ".mp4";
             hp.startMonitor(filename, shape);
         }
 
@@ -66,7 +66,7 @@ public class DataManagerTrainingListener implements TrainingListener {
     @Override
     public ListenerResponse onTrainingProgress(ILearning learning) {
         try {
-            int stepCounter = learning.getStepCounter();
+            int stepCounter = learning.getStepCount();
             if (stepCounter - lastSave >= Constants.MODEL_SAVE_FREQ) {
                 dataManager.save(learning);
                 lastSave = stepCounter;
@@ -8,7 +8,6 @@ import org.datavec.image.transform.CropImageTransform;
 import org.datavec.image.transform.MultiImageTransform;
 import org.datavec.image.transform.ResizeImageTransform;
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.learning.EpochStepCounter;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
@@ -30,10 +29,10 @@ import java.util.Map;
 
 import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;
 
-public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
+public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
 
     @Getter
-    private final MDP<O, A, AS> wrappedMDP;
+    private final MDP<OBSERVATION, A, AS> wrappedMDP;
     @Getter
     private final WrapperObservationSpace observationSpace;
     private final int[] shape;
@@ -44,16 +43,14 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     @Getter(AccessLevel.PRIVATE)
     private IHistoryProcessor historyProcessor;
 
-    private final EpochStepCounter epochStepCounter;
-
     private int skipFrame = 1;
+    private int steps = 0;
 
-    public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
+    public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> wrappedMDP, IHistoryProcessor historyProcessor) {
         this.wrappedMDP = wrappedMDP;
         this.shape = wrappedMDP.getObservationSpace().getShape();
         this.observationSpace = new WrapperObservationSpace(shape);
         this.historyProcessor = historyProcessor;
-        this.epochStepCounter = epochStepCounter;
 
         setHistoryProcessor(historyProcessor);
     }
@@ -63,6 +60,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
         createTransformProcess();
     }
 
+    //TODO: this transform process should be decoupled from history processor and configured seperately by the end-user
     private void createTransformProcess() {
         IHistoryProcessor historyProcessor = getHistoryProcessor();
 
@@ -103,7 +101,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     public Observation reset() {
         transformProcess.reset();
 
-        O rawResetResponse = wrappedMDP.reset();
+        OBSERVATION rawResetResponse = wrappedMDP.reset();
         record(rawResetResponse);
 
         if(historyProcessor != null) {
@@ -118,21 +116,21 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
     public StepReply<Observation> step(A a) {
         IHistoryProcessor historyProcessor = getHistoryProcessor();
 
-        StepReply<O> rawStepReply = wrappedMDP.step(a);
+        StepReply<OBSERVATION> rawStepReply = wrappedMDP.step(a);
         INDArray rawObservation = getInput(rawStepReply.getObservation());
 
         if(historyProcessor != null) {
            historyProcessor.record(rawObservation);
         }
 
-        int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
+        int stepOfObservation = steps++;
 
         Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
         Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
         return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
     }
 
-    private void record(O obs) {
+    private void record(OBSERVATION obs) {
         INDArray rawObservation = getInput(obs);
 
         IHistoryProcessor historyProcessor = getHistoryProcessor();
@@ -141,7 +139,7 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
         }
     }
 
-    private Map<String, Object> buildChannelsData(final O obs) {
+    private Map<String, Object> buildChannelsData(final OBSERVATION obs) {
         return new HashMap<String, Object>() {{
             put("data", obs);
         }};
@@ -159,11 +157,11 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
 
     @Override
     public MDP<Observation, A, AS> newInstance() {
-        return new LegacyMDPWrapper<O, A, AS>(wrappedMDP.newInstance(), historyProcessor, epochStepCounter);
+        return new LegacyMDPWrapper<>(wrappedMDP.newInstance(), historyProcessor);
     }
 
-    private INDArray getInput(O obs) {
-        INDArray arr = Nd4j.create(((Encodable)obs).toArray());
+    private INDArray getInput(OBSERVATION obs) {
+        INDArray arr = Nd4j.create(obs.toArray());
         int[] shape = observationSpace.getShape();
         if (shape.length == 1)
             return arr.reshape(new long[] {1, arr.length()});
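The step-numbering change in step() above is easy to miss: the wrapper now keeps its own counter instead of reading a counter shared with the learner. A tiny sketch of the two behaviours (class and method names are illustrative):

    class StepNumberingSketch {
        private int steps = 0;

        // Old: stepOfObservation = sharedCounter.getCurrentEpochStep() + 1 (first step numbered 1,
        // state owned by the learner). New, as in the hunk: the wrapper numbers steps itself.
        int nextStepOfObservation() {
            return steps++; // first observation handed to the transform process is numbered 0
        }
    }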
@@ -1,107 +1,107 @@
 package org.deeplearning4j.rl4j.experience;
 
 import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.junit.Test;
 import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.ArrayList;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
 public class ReplayMemoryExperienceHandlerTest {
     @Test
     public void when_addingFirstExperience_expect_notAddedToStoreBeforeNextObservationIsAdded() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
 
         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         int numStoredTransitions = expReplayMock.addedTransitions.size();
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
 
         // Assert
         assertEquals(0, numStoredTransitions);
         assertEquals(1, expReplayMock.addedTransitions.size());
     }
 
     @Test
     public void when_addingExperience_expect_transitionsAreCorrect() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
 
         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
 
         // Assert
         assertEquals(2, expReplayMock.addedTransitions.size());
 
         assertEquals(1.0, expReplayMock.addedTransitions.get(0).getObservation().getData().getDouble(0), 0.00001);
         assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
         assertEquals(1.0, expReplayMock.addedTransitions.get(0).getReward(), 0.00001);
         assertEquals(2.0, expReplayMock.addedTransitions.get(0).getNextObservation().getDouble(0), 0.00001);
 
         assertEquals(2.0, expReplayMock.addedTransitions.get(1).getObservation().getData().getDouble(0), 0.00001);
         assertEquals(2, (int)expReplayMock.addedTransitions.get(1).getAction());
         assertEquals(2.0, expReplayMock.addedTransitions.get(1).getReward(), 0.00001);
         assertEquals(3.0, expReplayMock.addedTransitions.get(1).getNextObservation().getDouble(0), 0.00001);
 
     }
 
     @Test
     public void when_settingFinalObservation_expect_nextAddedExperienceDoNotUsePreviousObservation() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
 
         // Act
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 2.0 })));
         sut.addExperience(new Observation(Nd4j.create(new double[] { 3.0 })), 3, 3.0, false);
 
         // Assert
         assertEquals(1, expReplayMock.addedTransitions.size());
         assertEquals(1, (int)expReplayMock.addedTransitions.get(0).getAction());
     }
 
     @Test
     public void when_addingExperience_expect_getTrainingBatchSizeReturnSize() {
         // Arrange
         TestExpReplay expReplayMock = new TestExpReplay();
         ReplayMemoryExperienceHandler sut = new ReplayMemoryExperienceHandler(expReplayMock);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 1.0 })), 1, 1.0, false);
         sut.addExperience(new Observation(Nd4j.create(new double[] { 2.0 })), 2, 2.0, false);
         sut.setFinalObservation(new Observation(Nd4j.create(new double[] { 3.0 })));
 
         // Act
         int size = sut.getTrainingBatchSize();
         // Assert
         assertEquals(2, size);
     }
 
     private static class TestExpReplay implements IExpReplay<Integer> {
 
         public final List<Transition<Integer>> addedTransitions = new ArrayList<>();
 
         @Override
         public ArrayList<Transition<Integer>> getBatch() {
             return null;
         }
 
         @Override
         public void store(Transition<Integer> transition) {
             addedTransitions.add(transition);
         }
 
         @Override
         public int getBatchSize() {
             return addedTransitions.size();
         }
     }
 }
@@ -18,132 +18,93 @@
 package org.deeplearning4j.rl4j.learning.async;
 
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
-import org.deeplearning4j.rl4j.mdp.MDP;
-import org.deeplearning4j.rl4j.policy.IPolicy;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
-import org.deeplearning4j.rl4j.support.MockEncodable;
-import org.deeplearning4j.rl4j.support.MockNeuralNet;
-import org.deeplearning4j.rl4j.support.MockPolicy;
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
+import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
+import org.junit.Before;
 import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
+@RunWith(MockitoJUnitRunner.class)
 public class AsyncLearningTest {
 
-    @Test
-    public void when_training_expect_AsyncGlobalStarted() {
-        // Arrange
-        TestContext context = new TestContext();
-        context.asyncGlobal.setMaxLoops(1);
-
-        // Act
-        context.sut.train();
-
-        // Assert
-        assertTrue(context.asyncGlobal.hasBeenStarted);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+    AsyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> asyncLearning;
+
+    @Mock
+    TrainingListener mockTrainingListener;
+
+    @Mock
+    AsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    IAsyncLearningConfiguration mockConfiguration;
+
+    @Before
+    public void setup() {
+        asyncLearning = mock(AsyncLearning.class, Mockito.withSettings()
+                .useConstructor()
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        asyncLearning.addListener(mockTrainingListener);
+
+        when(asyncLearning.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
+        when(asyncLearning.getConfiguration()).thenReturn(mockConfiguration);
+
+        // Don't actually start any threads in any of these tests
+        when(mockConfiguration.getNumThreads()).thenReturn(0);
     }
 
     @Test
     public void when_trainStartReturnsStop_expect_noTraining() {
         // Arrange
-        TestContext context = new TestContext();
-        context.listener.setRemainingTrainingStartCallCount(0);
+        when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
 
         // Act
-        context.sut.train();
+        asyncLearning.train();
 
         // Assert
-        assertEquals(1, context.listener.onTrainingStartCallCount);
-        assertEquals(1, context.listener.onTrainingEndCallCount);
-        assertEquals(0, context.policy.playCallCount);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }
 
     @Test
     public void when_trainingIsComplete_expect_trainingStop() {
         // Arrange
-        TestContext context = new TestContext();
+        when(mockAsyncGlobal.isTrainingComplete()).thenReturn(true);
 
         // Act
-        context.sut.train();
+        asyncLearning.train();
 
         // Assert
-        assertEquals(1, context.listener.onTrainingStartCallCount);
-        assertEquals(1, context.listener.onTrainingEndCallCount);
-        assertTrue(context.asyncGlobal.hasBeenTerminated);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }
 
     @Test
     public void when_training_expect_onTrainingProgressCalled() {
         // Arrange
-        TestContext context = new TestContext();
+        asyncLearning.setProgressMonitorFrequency(100);
+        when(mockTrainingListener.onTrainingProgress(eq(asyncLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);
 
         // Act
-        context.sut.train();
+        asyncLearning.train();
 
         // Assert
-        assertEquals(1, context.listener.onTrainingProgressCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
+        verify(mockTrainingListener, times(1)).onTrainingProgress(eq(asyncLearning));
     }
 
-    public static class TestContext {
-        MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0);
-        public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
-        public final MockPolicy policy = new MockPolicy();
-        public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
-        public final MockTrainingListener listener = new MockTrainingListener(asyncGlobal);
-
-        public TestContext() {
-            sut.addListener(listener);
-            asyncGlobal.setMaxLoops(1);
-            sut.setProgressMonitorFrequency(1);
-        }
-    }
-
-    public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
-        private final IAsyncLearningConfiguration conf;
-        private final IAsyncGlobal asyncGlobal;
-        private final IPolicy<MockEncodable, Integer> policy;
-
-        public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
-            this.conf = conf;
-            this.asyncGlobal = asyncGlobal;
-            this.policy = policy;
-        }
-
-        @Override
-        public IPolicy getPolicy() {
-            return policy;
-        }
-
-        @Override
-        public IAsyncLearningConfiguration getConfiguration() {
-            return conf;
-        }
-
-        @Override
-        protected AsyncThread newThread(int i, int deviceAffinity) {
-            return null;
-        }
-
-        @Override
-        public MDP getMdp() {
-            return null;
-        }
-
-        @Override
-        protected IAsyncGlobal getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        public MockNeuralNet getNeuralNet() {
-            return null;
-        }
-    }
 }
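The rewritten tests above lean on Mockito partial mocks: the abstract learner is constructed for real, its template-method logic runs, and only its collaborator getters are stubbed. A condensed, self-contained sketch of that pattern (the Worker class and its methods are illustrative, not rl4j types):

    import org.mockito.Mockito;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.verify;
    import static org.mockito.Mockito.when;

    public class PartialMockSketch {
        public static abstract class Worker {
            public abstract int budget();      // collaborator lookup, stubbed in the test
            public int run() {                 // real control flow under test
                int done = 0;
                while (done < budget()) done++;
                return done;
            }
        }

        public static void main(String[] args) {
            Worker worker = mock(Worker.class, Mockito.withSettings()
                    .useConstructor()                            // run the real constructor
                    .defaultAnswer(Mockito.CALLS_REAL_METHODS)); // keep real method bodies
            when(worker.budget()).thenReturn(3);                 // stub only the abstract getter
            System.out.println(worker.run());                    // prints 3: real loop, stubbed budget
            verify(worker, Mockito.atLeastOnce()).budget();
        }
    }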
@@ -1,6 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- * Copyright (c) 2020 Konduit K.K.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -17,161 +16,230 @@
 
 package org.deeplearning4j.rl4j.learning.async;
 
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.policy.IPolicy;
+import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.*;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
+import org.junit.Before;
 import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.factory.Nd4j;
 
-import java.util.ArrayList;
-import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
+@RunWith(MockitoJUnitRunner.class)
 public class AsyncThreadDiscreteTest {
 
+    AsyncThreadDiscrete<Encodable, NeuralNet> asyncThreadDiscrete;
+
+    @Mock
+    IAsyncLearningConfiguration mockAsyncConfiguration;
+
+    @Mock
+    UpdateAlgorithm<NeuralNet> mockUpdateAlgorithm;
+
+    @Mock
+    IAsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    Policy<Encodable, Integer> mockGlobalCurrentPolicy;
+
+    @Mock
+    NeuralNet mockGlobalTargetNetwork;
+
+    @Mock
+    MDP<Encodable, Integer, DiscreteSpace> mockMDP;
+
+    @Mock
+    LegacyMDPWrapper<Encodable, Integer, DiscreteSpace> mockLegacyMDPWrapper;
+
+    @Mock
+    DiscreteSpace mockActionSpace;
+
+    @Mock
+    ObservationSpace<Encodable> mockObservationSpace;
+
+    @Mock
+    TrainingListenerList mockTrainingListenerList;
+
+    @Mock
+    Observation mockObservation;
+
+    int[] observationShape = new int[]{3, 10, 10};
+    int actionSize = 4;
+
+    private void setupMDPMocks() {
+
+        when(mockActionSpace.noOp()).thenReturn(0);
+        when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
+
+        when(mockObservationSpace.getShape()).thenReturn(observationShape);
+        when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
+
+    }
+
+    private void setupNNMocks() {
+        when(mockAsyncGlobal.getTarget()).thenReturn(mockGlobalTargetNetwork);
+        when(mockGlobalTargetNetwork.clone()).thenReturn(mockGlobalTargetNetwork);
+    }
+
+    @Before
+    public void setup() {
+
+        setupMDPMocks();
+        setupNNMocks();
+
+        asyncThreadDiscrete = mock(AsyncThreadDiscrete.class, Mockito.withSettings()
+                .useConstructor(mockAsyncGlobal, mockMDP, mockTrainingListenerList, 0, 0)
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
+
+        when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
+        when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
+        when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
+        when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
+
+        when(mockGlobalCurrentPolicy.nextAction(any(Observation.class))).thenReturn(0);
+
+        when(asyncThreadDiscrete.getLegacyMDPWrapper()).thenReturn(mockLegacyMDPWrapper);
+
+    }
+
     @Test
-    public void refac_AsyncThreadDiscrete_trainSubEpoch() {
+    public void when_episodeCompletes_expect_stepsToBeInLineWithEpisodeLenth() {
 
         // Arrange
-        int numEpochs = 1;
-        MockNeuralNet nnMock = new MockNeuralNet();
-        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        MockHistoryProcessor hpMock = new MockHistoryProcessor(hpConf);
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(nnMock);
-        asyncGlobalMock.setMaxLoops(hpConf.getSkipFrame() * numEpochs);
-        MockObservationSpace observationSpace = new MockObservationSpace();
-        MockMDP mdpMock = new MockMDP(observationSpace);
-        TrainingListenerList listeners = new TrainingListenerList();
-        MockPolicy policyMock = new MockPolicy();
-        MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
-        MockExperienceHandler experienceHandlerMock = new MockExperienceHandler();
-        MockUpdateAlgorithm updateAlgorithmMock = new MockUpdateAlgorithm();
-        TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock, experienceHandlerMock, updateAlgorithmMock);
-        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
+        int episodeRemaining = 5;
+        int remainingTrainingSteps = 10;
+
+        // return done after 4 steps (the episode finishes before nsteps)
+        when(mockMDP.isDone()).thenAnswer(invocation ->
+                asyncThreadDiscrete.getStepCount() == episodeRemaining
+        );
+
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
 
         // Act
-        sut.run();
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
 
         // Assert
-        assertEquals(2, sut.trainSubEpochResults.size());
-        double[][] expectedLastObservations = new double[][] {
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
-        };
-        double[] expectedSubEpochReturnRewards = new double[] { 42.0, 58.0 };
-        for(int i = 0; i < 2; ++i) {
-            AsyncThread.SubEpochReturn result = sut.trainSubEpochResults.get(i);
-            assertEquals(4, result.getSteps());
-            assertEquals(expectedSubEpochReturnRewards[i], result.getReward(), 0.00001);
-            assertEquals(0.0, result.getScore(), 0.00001);
-
-            double[] expectedLastObservation = expectedLastObservations[i];
-            assertEquals(expectedLastObservation.length, result.getLastObs().getData().shape()[1]);
-            for(int j = 0; j < expectedLastObservation.length; ++j) {
-                assertEquals(expectedLastObservation[j], 255.0 * result.getLastObs().getData().getDouble(j), 0.00001);
-            }
-        }
-        assertEquals(2, asyncGlobalMock.enqueueCallCount);
-
-        // HistoryProcessor
-        double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
-        assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
-        for(int i = 0; i < expectedRecordValues.length; ++i) {
-            assertEquals(expectedRecordValues[i], hpMock.recordCalls.get(i).getDouble(0), 0.00001);
-        }
-
-        // Policy
-        double[][] expectedPolicyInputs = new double[][] {
-            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-        };
-        assertEquals(expectedPolicyInputs.length, policyMock.actionInputs.size());
-        for(int i = 0; i < expectedPolicyInputs.length; ++i) {
-            double[] expectedRow = expectedPolicyInputs[i];
-            INDArray input = policyMock.actionInputs.get(i);
-            assertEquals(expectedRow.length, input.shape()[1]);
-            for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
-            }
-        }
-
-        // ExperienceHandler
-        double[][] expectedExperienceHandlerInputs = new double[][] {
-            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-        };
-        assertEquals(expectedExperienceHandlerInputs.length, experienceHandlerMock.addExperienceArgs.size());
-        for(int i = 0; i < expectedExperienceHandlerInputs.length; ++i) {
-            double[] expectedRow = expectedExperienceHandlerInputs[i];
-            INDArray input = experienceHandlerMock.addExperienceArgs.get(i).getObservation().getData();
-            assertEquals(expectedRow.length, input.shape()[1]);
-            for(int j = 0; j < expectedRow.length; ++j) {
-                assertEquals(expectedRow[j], 255.0 * input.getDouble(j), 0.00001);
-            }
-        }
+        assertTrue(subEpochReturn.isEpisodeComplete());
+        assertEquals(5, subEpochReturn.getSteps());
     }
 
-    public static class TestAsyncThreadDiscrete extends AsyncThreadDiscrete<MockEncodable, MockNeuralNet> {
-
-        private final MockAsyncGlobal asyncGlobal;
-        private final MockPolicy policy;
-        private final MockAsyncConfiguration config;
-
-        public final List<SubEpochReturn> trainSubEpochResults = new ArrayList<SubEpochReturn>();
-
-        public TestAsyncThreadDiscrete(MockAsyncGlobal asyncGlobal, MDP<MockEncodable, Integer, DiscreteSpace> mdp,
-                                       TrainingListenerList listeners, int threadNumber, int deviceNum, MockPolicy policy,
-                                       MockAsyncConfiguration config, IHistoryProcessor hp,
-                                       ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
-                                       UpdateAlgorithm<MockNeuralNet> updateAlgorithm) {
-            super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
-            this.asyncGlobal = asyncGlobal;
-            this.policy = policy;
-            this.config = config;
-            setHistoryProcessor(hp);
-            setExperienceHandler(experienceHandler);
-            setUpdateAlgorithm(updateAlgorithm);
-        }
-
-        @Override
-        protected IAsyncGlobal<MockNeuralNet> getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        protected IAsyncLearningConfiguration getConf() {
-            return config;
-        }
-
-        @Override
-        protected IPolicy<MockEncodable, Integer> getPolicy(MockNeuralNet net) {
-            return policy;
-        }
-
-        @Override
-        protected UpdateAlgorithm<MockNeuralNet> buildUpdateAlgorithm() {
-            return null;
-        }
-
-        @Override
-        public SubEpochReturn trainSubEpoch(Observation sObs, int nstep) {
-            asyncGlobal.increaseCurrentLoop();
-            SubEpochReturn result = super.trainSubEpoch(sObs, nstep);
-            trainSubEpochResults.add(result);
-            return result;
-        }
-    }
+    @Test
+    public void when_episodeCompletesDueToMaxStepsReached_expect_isEpisodeComplete() {
+
+        // Arrange
+        int remainingTrainingSteps = 50;
+
+        // Episode does not complete due to MDP
+        when(mockMDP.isDone()).thenReturn(false);
+
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
+
+        when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(50);
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertTrue(subEpochReturn.isEpisodeComplete());
+        assertEquals(50, subEpochReturn.getSteps());
+    }
+
+    @Test
+    public void when_episodeLongerThanNsteps_expect_returnNStepLength() {
+
+        // Arrange
+        int episodeRemaining = 5;
+        int remainingTrainingSteps = 4;
+
+        // return done after 4 steps (the episode finishes before nsteps)
+        when(mockMDP.isDone()).thenAnswer(invocation ->
+                asyncThreadDiscrete.getStepCount() == episodeRemaining
+        );
+
+        when(mockLegacyMDPWrapper.step(0)).thenReturn(new StepReply<>(mockObservation, 0.0, false, null));
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertFalse(subEpochReturn.isEpisodeComplete());
+        assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
+    }
+
+    @Test
+    public void when_framesAreSkipped_expect_proportionateStepCounterUpdates() {
+        int skipFrames = 2;
+        int remainingTrainingSteps = 10;
+
+        // Episode does not complete due to MDP
+        when(mockMDP.isDone()).thenReturn(false);
+
+        AtomicInteger stepCount = new AtomicInteger();
+
+        // Use skipFrames to return if observations are skipped or not
+        when(mockLegacyMDPWrapper.step(anyInt())).thenAnswer(invocationOnMock -> {
+
+            boolean isSkipped = stepCount.incrementAndGet() % skipFrames != 0;
+
+            Observation mockObs = new Observation(isSkipped ? null : Nd4j.create(observationShape));
+            return new StepReply<>(mockObs, 0.0, false, null);
+        });
+
+        // Act
+        AsyncThread.SubEpochReturn subEpochReturn = asyncThreadDiscrete.trainSubEpoch(mockObservation, remainingTrainingSteps);
+
+        // Assert
+        assertFalse(subEpochReturn.isEpisodeComplete());
+        assertEquals(remainingTrainingSteps, subEpochReturn.getSteps());
+        assertEquals((remainingTrainingSteps - 1) * skipFrames + 1, stepCount.get());
+    }
+
+    @Test
+    public void when_preEpisodeCalled_expect_experienceHandlerReset() {
+
+        // Arrange
+        int trainingSteps = 100;
+        for (int i = 0; i < trainingSteps; i++) {
+            asyncThreadDiscrete.getExperienceHandler().addExperience(mockObservation, 0, 0.0, false);
+        }
+
+        int experienceHandlerSizeBeforeReset = asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize();
+
+        // Act
+        asyncThreadDiscrete.preEpisode();
+
+        // Assert
+        assertEquals(100, experienceHandlerSizeBeforeReset);
+        assertEquals(0, asyncThreadDiscrete.getExperienceHandler().getTrainingBatchSize());
+
+    }
 }
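The rewritten tests above (and the AsyncThread tests that follow) lean on one Mockito idiom throughout: the class under test is built through its real constructor but wrapped as a partial mock, so its real logic runs while individual collaborator getters are stubbed. A minimal self-contained sketch of that idiom, assuming Mockito 2+ on the classpath; the Greeter class is hypothetical and not part of RL4J:

    import org.mockito.Mockito;

    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    public class PartialMockSketch {

        // Hypothetical stand-in for AsyncThreadDiscrete: real behaviour in run(),
        // a collaborator exposed through an overridable getter.
        static class Greeter {
            private final String name;

            Greeter(String name) {
                this.name = name;
            }

            String getGreeting() {
                return "Hello";
            }

            String run() {
                return getGreeting() + ", " + name; // real method under test
            }
        }

        public static void main(String[] args) {
            // Construct the real object, but as a mock whose unstubbed methods
            // fall through to the real implementation.
            Greeter greeter = mock(Greeter.class, Mockito.withSettings()
                    .useConstructor("world")
                    .defaultAnswer(Mockito.CALLS_REAL_METHODS));

            // Stub only the collaborator getter; run() stays real.
            when(greeter.getGreeting()).thenReturn("Hi");

            System.out.println(greeter.run()); // prints "Hi, world"
        }
    }

This is the same withSettings().useConstructor(...).defaultAnswer(CALLS_REAL_METHODS) combination the tests use to stub getConf(), getAsyncGlobal() and getLegacyMDPWrapper() on a real AsyncThreadDiscrete.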
@@ -1,220 +1,277 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
 package org.deeplearning4j.rl4j.learning.async;
 
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.policy.Policy;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
-import org.deeplearning4j.rl4j.support.MockEncodable;
-import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
-import org.deeplearning4j.rl4j.support.MockMDP;
-import org.deeplearning4j.rl4j.support.MockNeuralNet;
-import org.deeplearning4j.rl4j.support.MockObservationSpace;
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.deeplearning4j.rl4j.util.IDataManager;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.shade.guava.base.Preconditions;
 
-import java.util.ArrayList;
-import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
+@RunWith(MockitoJUnitRunner.class)
 public class AsyncThreadTest {
 
-    @Test
-    public void when_newEpochStarted_expect_neuralNetworkReset() {
-        // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
-
-        // Act
-        context.sut.run();
-
-        // Assert
-        assertEquals(numberOfEpochs, context.neuralNet.resetCallCount);
+    @Mock
+    ActionSpace<INDArray> mockActionSpace;
+
+    @Mock
+    ObservationSpace<Box> mockObservationSpace;
+
+    @Mock
+    IAsyncLearningConfiguration mockAsyncConfiguration;
+
+    @Mock
+    NeuralNet mockNeuralNet;
+
+    @Mock
+    IAsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    MDP<Box, INDArray, ActionSpace<INDArray>> mockMDP;
+
+    @Mock
+    TrainingListenerList mockTrainingListeners;
+
+    int[] observationShape = new int[]{3, 10, 10};
+    int actionSize = 4;
+
+    AsyncThread<Box, INDArray, ActionSpace<INDArray>, NeuralNet> thread;
+
+    @Before
+    public void setup() {
+        setupMDPMocks();
+        setupThreadMocks();
+    }
+
+    private void setupThreadMocks() {
+
+        thread = mock(AsyncThread.class, Mockito.withSettings()
+                .useConstructor(mockMDP, mockTrainingListeners, 0, 0)
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        when(thread.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
+        when(thread.getCurrent()).thenReturn(mockNeuralNet);
+    }
+
+    private void setupMDPMocks() {
+
+        when(mockObservationSpace.getShape()).thenReturn(observationShape);
+        when(mockActionSpace.noOp()).thenReturn(Nd4j.zeros(actionSize));
+
+        when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
+        when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
+
+        int dataLength = 1;
+        for (int d : observationShape) {
+            dataLength *= d;
+        }
+
+        when(mockMDP.reset()).thenReturn(new Box(new double[dataLength]));
+    }
+
+    private void mockTrainingListeners() {
+        mockTrainingListeners(false, false);
+    }
+
+    private void mockTrainingListeners(boolean stopOnNotifyNewEpoch, boolean stopOnNotifyEpochTrainingResult) {
+        when(mockTrainingListeners.notifyNewEpoch(eq(thread))).thenReturn(!stopOnNotifyNewEpoch);
+        when(mockTrainingListeners.notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class))).thenReturn(!stopOnNotifyEpochTrainingResult);
+    }
+
+    private void mockTrainingContext() {
+        mockTrainingContext(1000, 100, 10);
+    }
+
+    private void mockTrainingContext(int maxSteps, int maxStepsPerEpisode, int nstep) {
+
+        // Some conditions of this test harness
+        Preconditions.checkArgument(maxStepsPerEpisode >= nstep, "episodeLength must be greater than or equal to nstep");
+        Preconditions.checkArgument(maxStepsPerEpisode % nstep == 0, "episodeLength must be a multiple of nstep");
+
+        Observation mockObs = new Observation(Nd4j.zeros(observationShape));
+
+        when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
+        when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
+        when(thread.getConf()).thenReturn(mockAsyncConfiguration);
+
+        // if we hit the max step count
+        when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
+
+        when(thread.trainSubEpoch(any(Observation.class), anyInt())).thenAnswer(invocationOnMock -> {
+            int steps = invocationOnMock.getArgument(1);
+            thread.stepCount += steps;
+            thread.currentEpisodeStepCount += steps;
+            boolean isEpisodeComplete = thread.getCurrentEpisodeStepCount() % maxStepsPerEpisode == 0;
+            return new AsyncThread.SubEpochReturn(steps, mockObs, 0.0, 0.0, isEpisodeComplete);
+        });
     }
 
     @Test
-    public void when_onNewEpochReturnsStop_expect_threadStopped() {
+    public void when_episodeComplete_expect_neuralNetworkReset() {
+
         // Arrange
-        int stopAfterNumCalls = 1;
-        TestContext context = new TestContext(100000);
-        context.listener.setRemainingOnNewEpochCallCount(stopAfterNumCalls);
+        mockTrainingContext(100, 10, 10);
+        mockTrainingListeners();
 
         // Act
-        context.sut.run();
+        thread.run();
 
         // Assert
-        assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: The call that returns stop is counted
-        assertEquals(stopAfterNumCalls, context.listener.onEpochTrainingResultCallCount);
+        verify(mockNeuralNet, times(10)).reset(); // there are 10 episodes so the network should be reset between each
+        assertEquals(10, thread.getEpochCount()); // We are performing a training iteration every 10 steps, so there should be 10 epochs
+        assertEquals(10, thread.getEpisodeCount()); // There should be 10 completed episodes
+        assertEquals(100, thread.getStepCount()); // 100 steps overall
     }
 
     @Test
-    public void when_epochTrainingResultReturnsStop_expect_threadStopped() {
+    public void when_notifyNewEpochReturnsStop_expect_threadStopped() {
         // Arrange
-        int stopAfterNumCalls = 1;
-        TestContext context = new TestContext(100000);
-        context.listener.setRemainingOnEpochTrainingResult(stopAfterNumCalls);
+        mockTrainingContext();
+        mockTrainingListeners(true, false);
 
         // Act
-        context.sut.run();
+        thread.run();
 
         // Assert
-        assertEquals(stopAfterNumCalls + 1, context.listener.onEpochTrainingResultCallCount); // +1: The call that returns stop is counted
-        assertEquals(stopAfterNumCalls + 1, context.listener.onNewEpochCallCount); // +1: onNewEpoch is called on the epoch that onEpochTrainingResult() will stop
+        assertEquals(0, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(0, thread.getStepCount());
     }
 
     @Test
-    public void when_run_expect_preAndPostEpochCalled() {
+    public void when_notifyEpochTrainingResultReturnsStop_expect_threadStopped() {
         // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
+        mockTrainingContext();
+        mockTrainingListeners(false, true);
 
         // Act
-        context.sut.run();
+        thread.run();
 
         // Assert
-        assertEquals(numberOfEpochs, context.sut.preEpochCallCount);
-        assertEquals(numberOfEpochs, context.sut.postEpochCallCount);
+        assertEquals(1, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(10, thread.getStepCount()); // one epoch is by default 10 steps
+    }
+
+    @Test
+    public void when_run_expect_preAndPostEpisodeCalled() {
+        // Arrange
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        verify(thread, times(10)).preEpisode(); // over 100 steps there will be 10 episodes
+        verify(thread, times(10)).postEpisode();
     }
 
     @Test
     public void when_run_expect_trainSubEpochCalledAndResultPassedToListeners() {
         // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
 
         // Act
-        context.sut.run();
+        thread.run();
 
         // Assert
-        assertEquals(numberOfEpochs, context.listener.statEntries.size());
-        int[] expectedStepCounter = new int[] { 10, 20, 30, 40, 50 };
-        double expectedReward = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0 + 7.0 + 8.0) // reward from init
-                + 1.0; // Reward from trainSubEpoch()
-        for(int i = 0; i < numberOfEpochs; ++i) {
-            IDataManager.StatEntry statEntry = context.listener.statEntries.get(i);
-            assertEquals(expectedStepCounter[i], statEntry.getStepCounter());
-            assertEquals(i, statEntry.getEpochCounter());
-            assertEquals(expectedReward, statEntry.getReward(), 0.0001);
-        }
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // Over 100 steps there will be 20 training iterations, so there will be 20 calls to notifyEpochTrainingResult
+        verify(mockTrainingListeners, times(20)).notifyEpochTrainingResult(eq(thread), any(IDataManager.StatEntry.class));
     }
 
     @Test
     public void when_run_expect_trainSubEpochCalled() {
         // Arrange
-        int numberOfEpochs = 5;
-        TestContext context = new TestContext(numberOfEpochs);
+        mockTrainingContext(100, 10, 5);
+        mockTrainingListeners(false, false);
 
         // Act
-        context.sut.run();
+        thread.run();
 
         // Assert
-        assertEquals(numberOfEpochs, context.sut.trainSubEpochParams.size());
-        double[] expectedObservation = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 };
-        for(int i = 0; i < context.sut.trainSubEpochParams.size(); ++i) {
-            MockAsyncThread.TrainSubEpochParams params = context.sut.trainSubEpochParams.get(i);
-            assertEquals(2, params.nstep);
-            assertEquals(expectedObservation.length, params.obs.getData().shape()[1]);
-            for(int j = 0; j < expectedObservation.length; ++j){
-                assertEquals(expectedObservation[j], 255.0 * params.obs.getData().getDouble(j), 0.00001);
-            }
-        }
-    }
-
-    private static class TestContext {
-        public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
-        public final MockNeuralNet neuralNet = new MockNeuralNet();
-        public final MockObservationSpace observationSpace = new MockObservationSpace();
-        public final MockMDP mdp = new MockMDP(observationSpace);
-        public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
-        public final TrainingListenerList listeners = new TrainingListenerList();
-        public final MockTrainingListener listener = new MockTrainingListener();
-        public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        public final MockHistoryProcessor historyProcessor = new MockHistoryProcessor(hpConf);
-
-        public final MockAsyncThread sut = new MockAsyncThread(asyncGlobal, 0, neuralNet, mdp, config, listeners);
-
-        public TestContext(int numEpochs) {
-            asyncGlobal.setMaxLoops(numEpochs);
-            listeners.add(listener);
-            sut.setHistoryProcessor(historyProcessor);
-            sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
-        }
-    }
-
-    public static class MockAsyncThread extends AsyncThread<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
-
-        public int preEpochCallCount = 0;
-        public int postEpochCallCount = 0;
-
-        private final MockAsyncGlobal asyncGlobal;
-        private final MockNeuralNet neuralNet;
-        private final IAsyncLearningConfiguration conf;
-
-        private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
-
-        public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
-            super(asyncGlobal, mdp, listeners, threadNumber, 0);
-
-            this.asyncGlobal = asyncGlobal;
-            this.neuralNet = neuralNet;
-            this.conf = conf;
-        }
-
-        @Override
-        protected void preEpoch() {
-            ++preEpochCallCount;
-            super.preEpoch();
-        }
-
-        @Override
-        protected void postEpoch() {
-            ++postEpochCallCount;
-            super.postEpoch();
-        }
-
-        @Override
-        protected MockNeuralNet getCurrent() {
-            return neuralNet;
-        }
-
-        @Override
-        protected IAsyncGlobal getAsyncGlobal() {
-            return asyncGlobal;
-        }
-
-        @Override
-        protected IAsyncLearningConfiguration getConf() {
-            return conf;
-        }
-
-        @Override
-        protected Policy getPolicy(MockNeuralNet net) {
-            return null;
-        }
-
-        @Override
-        protected SubEpochReturn trainSubEpoch(Observation obs, int nstep) {
-            asyncGlobal.increaseCurrentLoop();
-            trainSubEpochParams.add(new TrainSubEpochParams(obs, nstep));
-            for(int i = 0; i < nstep; ++i) {
-                incrementStep();
-            }
-            return new SubEpochReturn(nstep, null, 1.0, 1.0);
-        }
-
-        @AllArgsConstructor
-        @Getter
-        public static class TrainSubEpochParams {
-            Observation obs;
-            int nstep;
-        }
-    }
+        assertEquals(20, thread.getEpochCount());
+        assertEquals(10, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // There should be 20 calls to trainsubepoch with 5 steps per epoch
+        verify(thread, times(20)).trainSubEpoch(any(Observation.class), eq(5));
+    }
+
+    @Test
+    public void when_remainingEpisodeLengthSmallerThanNSteps_expect_trainSubEpochCalledWithMinimumValue() {
+
+        int currentEpisodeSteps = 95;
+        mockTrainingContext(1000, 100, 10);
+        mockTrainingListeners(false, true);
+
+        // want to mock that we are 95 steps into the episode
+        doAnswer(invocationOnMock -> {
+            for (int i = 0; i < currentEpisodeSteps; i++) {
+                thread.incrementSteps();
+            }
+            return null;
+        }).when(thread).preEpisode();
+
+        mockTrainingListeners(false, true);
+
+        // Act
+        thread.run();
+
+        // Assert
+        assertEquals(1, thread.getEpochCount());
+        assertEquals(1, thread.getEpisodeCount());
+        assertEquals(100, thread.getStepCount());
+
+        // There should be 1 call to trainsubepoch with 5 steps as this is the remaining episode steps
+        verify(thread, times(1)).trainSubEpoch(any(Observation.class), eq(5));
+    }
 }
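The doAnswer(...).when(...) construction used above is Mockito's way of attaching behaviour to a void method, since when(mock.voidMethod()) does not compile. A small self-contained sketch of the idiom; the EpisodeHooks interface is hypothetical, not an RL4J type:

    import java.util.concurrent.atomic.AtomicInteger;

    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    public class DoAnswerSketch {

        // Hypothetical stand-in for AsyncThread's void preEpisode() hook.
        interface EpisodeHooks {
            void preEpisode();
        }

        public static void main(String[] args) {
            AtomicInteger stepCount = new AtomicInteger();
            EpisodeHooks hooks = mock(EpisodeHooks.class);

            // Give the void method a side effect: pretend we are already
            // 95 steps into the episode when preEpisode() fires.
            doAnswer(invocation -> {
                stepCount.addAndGet(95);
                return null;
            }).when(hooks).preEpisode();

            hooks.preEpisode();
            System.out.println(stepCount.get()); // prints 95
        }
    }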
@@ -1,160 +0,0 @@
-package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
-
-import org.deeplearning4j.nn.api.NeuralNetwork;
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.network.ac.IActorCritic;
-import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
-import org.deeplearning4j.rl4j.support.MockMDP;
-import org.deeplearning4j.rl4j.support.MockObservationSpace;
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.primitives.Pair;
-
-import java.io.IOException;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-
-public class A3CUpdateAlgorithmTest {
-
-    @Test
-    public void refac_calcGradient_non_terminal() {
-        // Arrange
-        double gamma = 0.9;
-        MockObservationSpace observationSpace = new MockObservationSpace(new int[] { 5 });
-        MockMDP mdpMock = new MockMDP(observationSpace);
-        MockActorCritic actorCriticMock = new MockActorCritic();
-        MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
-        A3CUpdateAlgorithm sut = new A3CUpdateAlgorithm(asyncGlobalMock, observationSpace.getShape(), mdpMock.getActionSpace().getSize(), -1, gamma);
-
-        INDArray[] originalObservations = new INDArray[] {
-            Nd4j.create(new double[] { 0.0, 0.1, 0.2, 0.3, 0.4 }),
-            Nd4j.create(new double[] { 1.0, 1.1, 1.2, 1.3, 1.4 }),
-            Nd4j.create(new double[] { 2.0, 2.1, 2.2, 2.3, 2.4 }),
-            Nd4j.create(new double[] { 3.0, 3.1, 3.2, 3.3, 3.4 }),
-        };
-        int[] actions = new int[] { 0, 1, 2, 1 };
-        double[] rewards = new double[] { 0.1, 1.0, 10.0, 100.0 };
-
-        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
-        for(int i = 0; i < originalObservations.length; ++i) {
-            experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
-        }
-
-        // Act
-        sut.computeGradients(actorCriticMock, experience);
-
-        // Assert
-        assertEquals(1, actorCriticMock.gradientParams.size());
-
-        // Inputs
-        INDArray input = actorCriticMock.gradientParams.get(0).getLeft();
-        for(int i = 0; i < 4; ++i) {
-            for(int j = 0; j < 5; ++j) {
-                assertEquals(i + j / 10.0, input.getDouble(i, j), 0.00001);
-            }
-        }
-
-        INDArray targets = actorCriticMock.gradientParams.get(0).getRight()[0];
-        INDArray logSoftmax = actorCriticMock.gradientParams.get(0).getRight()[1];
-
-        assertEquals(4, targets.shape()[0]);
-        assertEquals(1, targets.shape()[1]);
-
-        // FIXME: check targets values once fixed
-
-        assertEquals(4, logSoftmax.shape()[0]);
-        assertEquals(5, logSoftmax.shape()[1]);
-
-        // FIXME: check logSoftmax values once fixed
-
-    }
-
-    public class MockActorCritic implements IActorCritic {
-
-        public final List<Pair<INDArray, INDArray[]>> gradientParams = new ArrayList<>();
-
-        @Override
-        public NeuralNetwork[] getNeuralNetworks() {
-            return new NeuralNetwork[0];
-        }
-
-        @Override
-        public boolean isRecurrent() {
-            return false;
-        }
-
-        @Override
-        public void reset() {
-
-        }
-
-        @Override
-        public void fit(INDArray input, INDArray[] labels) {
-
-        }
-
-        @Override
-        public INDArray[] outputAll(INDArray batch) {
-            return new INDArray[] { batch.mul(-1.0) };
-        }
-
-        @Override
-        public IActorCritic clone() {
-            return this;
-        }
-
-        @Override
-        public void copy(NeuralNet from) {
-
-        }
-
-        @Override
-        public void copy(IActorCritic from) {
-
-        }
-
-        @Override
-        public Gradient[] gradient(INDArray input, INDArray[] labels) {
-            gradientParams.add(new Pair<INDArray, INDArray[]>(input, labels));
-            return new Gradient[0];
-        }
-
-        @Override
-        public void applyGradient(Gradient[] gradient, int batchSize) {
-
-        }
-
-        @Override
-        public void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException {
-
-        }
-
-        @Override
-        public void save(String pathValue, String pathPolicy) throws IOException {
-
-        }
-
-        @Override
-        public double getLatestScore() {
-            return 0;
-        }
-
-        @Override
-        public void save(OutputStream os) throws IOException {
-
-        }
-
-        @Override
-        public void save(String filename) throws IOException {
-
-        }
-    }
-}
@@ -0,0 +1,93 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
+
+import org.deeplearning4j.rl4j.experience.StateActionPair;
+import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
+import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.network.ac.IActorCritic;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class AdvantageActorCriticUpdateAlgorithmTest {
+
+    @Mock
+    AsyncGlobal<NeuralNet> mockAsyncGlobal;
+
+    @Mock
+    IActorCritic mockActorCritic;
+
+    @Test
+    public void refac_calcGradient_non_terminal() {
+        // Arrange
+        int[] observationShape = new int[]{5};
+        double gamma = 0.9;
+        AdvantageActorCriticUpdateAlgorithm algorithm = new AdvantageActorCriticUpdateAlgorithm(false, observationShape, 1, gamma);
+
+        INDArray[] originalObservations = new INDArray[]{
+                Nd4j.create(new double[]{0.0, 0.1, 0.2, 0.3, 0.4}),
+                Nd4j.create(new double[]{1.0, 1.1, 1.2, 1.3, 1.4}),
+                Nd4j.create(new double[]{2.0, 2.1, 2.2, 2.3, 2.4}),
+                Nd4j.create(new double[]{3.0, 3.1, 3.2, 3.3, 3.4}),
+        };
+
+        int[] actions = new int[]{0, 1, 2, 1};
+        double[] rewards = new double[]{0.1, 1.0, 10.0, 100.0};
+
+        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>();
+        for (int i = 0; i < originalObservations.length; ++i) {
+            experience.add(new StateActionPair<>(new Observation(originalObservations[i]), actions[i], rewards[i], false));
+        }
+
+        when(mockActorCritic.outputAll(any(INDArray.class))).thenAnswer(invocation -> {
+            INDArray batch = invocation.getArgument(0);
+            return new INDArray[]{batch.mul(-1.0)};
+        });
+
+        ArgumentCaptor<INDArray> inputArgumentCaptor = ArgumentCaptor.forClass(INDArray.class);
+        ArgumentCaptor<INDArray[]> criticActorArgumentCaptor = ArgumentCaptor.forClass(INDArray[].class);
+
+        // Act
+        algorithm.computeGradients(mockActorCritic, experience);
+
+        verify(mockActorCritic, times(1)).gradient(inputArgumentCaptor.capture(), criticActorArgumentCaptor.capture());
+
+        assertEquals(Nd4j.stack(0, originalObservations), inputArgumentCaptor.getValue());
+
+        //TODO: the actual AdvantageActorCritic Algo is not implemented correctly, so needs to be fixed, then we can test these
+        // assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[0]);
+        // assertEquals(Nd4j.zeros(1), criticActorArgumentCaptor.getValue()[1]);
+
+    }
+
+}
@@ -1,3 +1,19 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
 package org.deeplearning4j.rl4j.learning.async.listener;
 
 import org.deeplearning4j.rl4j.learning.IEpochTrainer;
@@ -1,11 +1,30 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
 package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
 
 import org.deeplearning4j.rl4j.experience.StateActionPair;
+import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
 import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
 import org.deeplearning4j.rl4j.observation.Observation;
-import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
 import org.deeplearning4j.rl4j.support.MockDQN;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
@@ -14,17 +33,21 @@ import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
+@RunWith(MockitoJUnitRunner.class)
 public class QLearningUpdateAlgorithmTest {
 
+    @Mock
+    AsyncGlobal mockAsyncGlobal;
+
     @Test
     public void when_isTerminal_expect_initRewardIs0() {
         // Arrange
         MockDQN dqnMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 1 }, 1, -1, 1.0);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 1 }, 1, 1.0);
+        final Observation observation = new Observation(Nd4j.zeros(1));
         List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
             {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.zeros(1)), 0, 0.0, true));
+                add(new StateActionPair<Integer>(observation, 0, 0.0, true));
             }
         };
 
@@ -38,12 +61,11 @@ public class QLearningUpdateAlgorithmTest {
     @Test
     public void when_terminalAndNoTargetUpdate_expect_initRewardWithMaxQFromCurrent() {
         // Arrange
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, -1, 1.0);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, 1.0);
+        final Observation observation = new Observation(Nd4j.create(new double[] { -123.0, -234.0 }));
         List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
             {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
+                add(new StateActionPair<Integer>(observation, 0, 0.0, false));
             }
         };
         MockDQN dqnMock = new MockDQN();
@@ -57,35 +79,11 @@ public class QLearningUpdateAlgorithmTest {
         assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
     }
 
-    @Test
-    public void when_terminalWithTargetUpdate_expect_initRewardWithMaxQFromGlobal() {
-        // Arrange
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, 1.0);
-        List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
-            {
-                add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -123.0, -234.0 })), 0, 0.0, false));
-            }
-        };
-        MockDQN dqnMock = new MockDQN();
-
-        // Act
-        sut.computeGradients(dqnMock, experience);
-
-        // Assert
-        assertEquals(1, globalDQNMock.outputAllParams.size());
-        assertEquals(-123.0, globalDQNMock.outputAllParams.get(0).getDouble(0, 0), 0.00001);
-        assertEquals(234.0, dqnMock.gradientParams.get(0).getRight().getDouble(0), 0.00001);
-    }
-
     @Test
     public void when_callingWithMultipleExperiences_expect_gradientsAreValid() {
         // Arrange
         double gamma = 0.9;
-        MockDQN globalDQNMock = new MockDQN();
-        MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(globalDQNMock);
-        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(asyncGlobalMock, new int[] { 2 }, 2, 1, gamma);
+        UpdateAlgorithm sut = new QLearningUpdateAlgorithm(new int[] { 2 }, 2, gamma);
         List<StateActionPair<Integer>> experience = new ArrayList<StateActionPair<Integer>>() {
             {
                 add(new StateActionPair<Integer>(new Observation(Nd4j.create(new double[] { -1.1, -1.2 })), 0, 1.0, false));
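The QLearningUpdateAlgorithm assertions above encode the usual n-step target: walk the experience backward, discounting the accumulated return by gamma at each step, seeded with 0.0 on a terminal transition (or a max-Q bootstrap otherwise). A standalone sketch of that accumulation, illustrative only and not the RL4J implementation:

    public class NStepReturnSketch {

        // Backward accumulation: R_i = r_i + gamma * R_{i+1}, seeded with the
        // bootstrap value (0.0 if the last transition is terminal).
        static double[] nStepReturns(double[] rewards, double bootstrap, double gamma) {
            double[] returns = new double[rewards.length];
            double running = bootstrap;
            for (int i = rewards.length - 1; i >= 0; i--) {
                running = rewards[i] + gamma * running;
                returns[i] = running;
            }
            return returns;
        }

        public static void main(String[] args) {
            // Terminal episode, as in when_isTerminal_expect_initRewardIs0(): bootstrap is 0.0
            double[] returns = nStepReturns(new double[] { 1.0, 1.0 }, 0.0, 0.9);
            System.out.println(returns[0] + " " + returns[1]); // prints 1.9 1.0
        }
    }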
@@ -1,20 +1,56 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
 package org.deeplearning4j.rl4j.learning.listener;
 
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.learning.IEpochTrainer;
+import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.util.IDataManager;
 import org.junit.Test;
+import org.mockito.Mock;
 
 import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TrainingListenerListTest {
 
+    @Mock
+    IEpochTrainer mockTrainer;
+
+    @Mock
+    ILearning mockLearning;
+
+    @Mock
+    IDataManager.StatEntry mockStatEntry;
+
     @Test
     public void when_listIsEmpty_expect_notifyReturnTrue() {
         // Arrange
-        TrainingListenerList sut = new TrainingListenerList();
+        TrainingListenerList trainingListenerList = new TrainingListenerList();
 
         // Act
-        boolean resultTrainingStarted = sut.notifyTrainingStarted();
-        boolean resultNewEpoch = sut.notifyNewEpoch(null);
-        boolean resultEpochFinished = sut.notifyEpochTrainingResult(null, null);
+        boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
+        boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
+        boolean resultEpochFinished = trainingListenerList.notifyEpochTrainingResult(null, null);
 
         // Assert
         assertTrue(resultTrainingStarted);
@@ -25,54 +61,56 @@ public class TrainingListenerListTest {
     @Test
     public void when_firstListerStops_expect_othersListnersNotCalled() {
         // Arrange
-        MockTrainingListener listener1 = new MockTrainingListener();
-        listener1.setRemainingTrainingStartCallCount(0);
-        listener1.setRemainingOnNewEpochCallCount(0);
-        listener1.setRemainingonTrainingProgressCallCount(0);
-        listener1.setRemainingOnEpochTrainingResult(0);
-        MockTrainingListener listener2 = new MockTrainingListener();
-        TrainingListenerList sut = new TrainingListenerList();
-        sut.add(listener1);
-        sut.add(listener2);
+        TrainingListener listener1 = mock(TrainingListener.class);
+        TrainingListener listener2 = mock(TrainingListener.class);
+        TrainingListenerList trainingListenerList = new TrainingListenerList();
+        trainingListenerList.add(listener1);
+        trainingListenerList.add(listener2);
+
+        when(listener1.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onNewEpoch(eq(mockTrainer))).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry))).thenReturn(TrainingListener.ListenerResponse.STOP);
+        when(listener1.onTrainingProgress(eq(mockLearning))).thenReturn(TrainingListener.ListenerResponse.STOP);
 
         // Act
-        sut.notifyTrainingStarted();
-        sut.notifyNewEpoch(null);
-        sut.notifyEpochTrainingResult(null, null);
-        sut.notifyTrainingProgress(null);
-        sut.notifyTrainingFinished();
+        trainingListenerList.notifyTrainingStarted();
+        trainingListenerList.notifyNewEpoch(mockTrainer);
+        trainingListenerList.notifyEpochTrainingResult(mockTrainer, null);
+        trainingListenerList.notifyTrainingProgress(mockLearning);
+        trainingListenerList.notifyTrainingFinished();
 
         // Assert
-        assertEquals(1, listener1.onTrainingStartCallCount);
-        assertEquals(0, listener2.onTrainingStartCallCount);
-
-        assertEquals(1, listener1.onNewEpochCallCount);
-        assertEquals(0, listener2.onNewEpochCallCount);
-
-        assertEquals(1, listener1.onEpochTrainingResultCallCount);
-        assertEquals(0, listener2.onEpochTrainingResultCallCount);
-
-        assertEquals(1, listener1.onTrainingProgressCallCount);
-        assertEquals(0, listener2.onTrainingProgressCallCount);
-
-        assertEquals(1, listener1.onTrainingEndCallCount);
-        assertEquals(1, listener2.onTrainingEndCallCount);
+        verify(listener1, times(1)).onTrainingStart();
+        verify(listener2, never()).onTrainingStart();
+
+        verify(listener1, times(1)).onNewEpoch(eq(mockTrainer));
+        verify(listener2, never()).onNewEpoch(eq(mockTrainer));
+
+        verify(listener1, times(1)).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
+        verify(listener2, never()).onEpochTrainingResult(eq(mockTrainer), eq(mockStatEntry));
+
+        verify(listener1, times(1)).onTrainingProgress(eq(mockLearning));
+        verify(listener2, never()).onTrainingProgress(eq(mockLearning));
+
+        verify(listener1, times(1)).onTrainingEnd();
+        verify(listener2, times(1)).onTrainingEnd();
     }
 
     @Test
     public void when_allListenersContinue_expect_listReturnsTrue() {
         // Arrange
-        MockTrainingListener listener1 = new MockTrainingListener();
-        MockTrainingListener listener2 = new MockTrainingListener();
-        TrainingListenerList sut = new TrainingListenerList();
-        sut.add(listener1);
-        sut.add(listener2);
+        TrainingListener listener1 = mock(TrainingListener.class);
+        TrainingListener listener2 = mock(TrainingListener.class);
+        TrainingListenerList trainingListenerList = new TrainingListenerList();
+        trainingListenerList.add(listener1);
+        trainingListenerList.add(listener2);
 
// Act
|
// Act
|
||||||
boolean resultTrainingStarted = sut.notifyTrainingStarted();
|
boolean resultTrainingStarted = trainingListenerList.notifyTrainingStarted();
|
||||||
boolean resultNewEpoch = sut.notifyNewEpoch(null);
|
boolean resultNewEpoch = trainingListenerList.notifyNewEpoch(null);
|
||||||
boolean resultEpochTrainingResult = sut.notifyEpochTrainingResult(null, null);
|
boolean resultEpochTrainingResult = trainingListenerList.notifyEpochTrainingResult(null, null);
|
||||||
boolean resultProgress = sut.notifyTrainingProgress(null);
|
boolean resultProgress = trainingListenerList.notifyTrainingProgress(null);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertTrue(resultTrainingStarted);
|
assertTrue(resultTrainingStarted);
|
||||||
|
|
|
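The rewritten test above swaps the hand-rolled MockTrainingListener for Mockito stubs and verifications. Below is a minimal sketch of the same stub-then-verify pattern, assuming only the RL4J TrainingListener and TrainingListenerList types shown in this diff; the sketch's own class and method names are hypothetical, and the assertFalse reflects the expectation (implied by the test names above) that a STOP response makes the notify method return false.

import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.junit.Test;

import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.*;

// Hypothetical test class illustrating the stub-then-verify pattern used above.
public class ListenerStopSketch {

    @Test
    public void firstListenerStops_othersAreNotNotified() {
        TrainingListener first = mock(TrainingListener.class);
        TrainingListener second = mock(TrainingListener.class);
        when(first.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);

        TrainingListenerList listeners = new TrainingListenerList();
        listeners.add(first);
        listeners.add(second);

        // A STOP response is expected to short-circuit the notification chain.
        assertFalse(listeners.notifyTrainingStarted());
        verify(first, times(1)).onTrainingStart();
        verify(second, never()).onTrainingStart();
    }
}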
@@ -1,151 +1,117 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- * Copyright (c) 2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
 package org.deeplearning4j.rl4j.learning.sync;

-import lombok.Getter;
+import org.deeplearning4j.rl4j.learning.ILearning;
 import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
-import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
-import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
 import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
-import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
-import org.deeplearning4j.rl4j.policy.IPolicy;
-import org.deeplearning4j.rl4j.support.MockTrainingListener;
+import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.util.IDataManager;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.nd4j.linalg.api.ndarray.INDArray;

-import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class SyncLearningTest {

+    @Mock
+    TrainingListener mockTrainingListener;
+
+    SyncLearning<Box, INDArray, ActionSpace<INDArray>, NeuralNet> syncLearning;
+
+    @Mock
+    ILearningConfiguration mockLearningConfiguration;
+
+    @Before
+    public void setup() {
+
+        syncLearning = mock(SyncLearning.class, Mockito.withSettings()
+                .useConstructor()
+                .defaultAnswer(Mockito.CALLS_REAL_METHODS));
+
+        syncLearning.addListener(mockTrainingListener);
+
+        when(syncLearning.trainEpoch()).thenAnswer(invocation -> {
+            //syncLearning.incrementEpoch();
+            syncLearning.incrementStep();
+            return new MockStatEntry(syncLearning.getEpochCount(), syncLearning.getStepCount(), 1.0);
+        });
+
+        when(syncLearning.getConfiguration()).thenReturn(mockLearningConfiguration);
+        when(mockLearningConfiguration.getMaxStep()).thenReturn(100);
+    }

     @Test
     public void when_training_expect_listenersToBeCalled() {
-        // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-
         // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(10, listener.onNewEpochCallCount);
-        assertEquals(10, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(100)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(100)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }

     @Test
     public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
         // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingTrainingStartCallCount(0);
+        when(mockTrainingListener.onTrainingStart()).thenReturn(TrainingListener.ListenerResponse.STOP);

         // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(0, listener.onNewEpochCallCount);
-        assertEquals(0, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(0)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(0)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }

     @Test
     public void when_newEpochCanContinueFalse_expect_trainingStopped() {
         // Arrange
-        QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingOnNewEpochCallCount(2);
+        when(mockTrainingListener.onNewEpoch(eq(syncLearning)))
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.STOP);

         // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(3, listener.onNewEpochCallCount);
-        assertEquals(2, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(2)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }

     @Test
     public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
         // Arrange
-        LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
-        MockTrainingListener listener = new MockTrainingListener();
-        MockSyncLearning sut = new MockSyncLearning(lconfig);
-        sut.addListener(listener);
-        listener.setRemainingOnEpochTrainingResult(2);
+        when(mockTrainingListener.onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class)))
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.CONTINUE)
+                .thenReturn(TrainingListener.ListenerResponse.STOP);

         // Act
-        sut.train();
+        syncLearning.train();

-        assertEquals(1, listener.onTrainingStartCallCount);
-        assertEquals(3, listener.onNewEpochCallCount);
-        assertEquals(3, listener.onEpochTrainingResultCallCount);
-        assertEquals(1, listener.onTrainingEndCallCount);
-    }
-
-    public static class MockSyncLearning extends SyncLearning {
-
-        private final ILearningConfiguration conf;
-
-        @Getter
-        private int currentEpochStep = 0;
-
-        public MockSyncLearning(ILearningConfiguration conf) {
-            this.conf = conf;
-        }
-
-        @Override
-        protected void preEpoch() { currentEpochStep = 0; }
-
-        @Override
-        protected void postEpoch() { }
-
-        @Override
-        protected IDataManager.StatEntry trainEpoch() {
-            setStepCounter(getStepCounter() + 1);
-            return new MockStatEntry(getCurrentEpochStep(), getStepCounter(), 1.0);
-        }
-
-        @Override
-        public NeuralNet getNeuralNet() {
-            return null;
-        }
-
-        @Override
-        public IPolicy getPolicy() {
-            return null;
-        }
-
-        @Override
-        public ILearningConfiguration getConfiguration() {
-            return conf;
-        }
-
-        @Override
-        public MDP getMdp() {
-            return null;
-        }
+        verify(mockTrainingListener, times(1)).onTrainingStart();
+        verify(mockTrainingListener, times(3)).onNewEpoch(eq(syncLearning));
+        verify(mockTrainingListener, times(3)).onEpochTrainingResult(eq(syncLearning), any(IDataManager.StatEntry.class));
+        verify(mockTrainingListener, times(1)).onTrainingEnd();
     }
 }
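The key Mockito idiom in the new SyncLearningTest is the partial mock: withSettings().useConstructor().defaultAnswer(CALLS_REAL_METHODS) instantiates the abstract class through its real constructor and runs real method bodies, while selected methods (trainEpoch(), getConfiguration()) are stubbed. Below is a self-contained sketch of the same idea on a hypothetical abstract class; Job, runAll and runOnce are invented for illustration and stand in for SyncLearning and its template methods.

import org.mockito.Mockito;

import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

// Hypothetical abstract class standing in for SyncLearning in this sketch.
abstract class Job {
    int runs = 0;
    public void runAll(int n) { for (int i = 0; i < n; i++) runOnce(); } // real logic under test
    protected abstract void runOnce();                                   // controlled by the test
}

public class PartialMockSketch {
    public static void main(String[] args) {
        Job job = mock(Job.class, Mockito.withSettings()
                .useConstructor()                            // run the real (default) constructor
                .defaultAnswer(Mockito.CALLS_REAL_METHODS)); // real method bodies, not stubs
        // Stub only the abstract hook; the concrete runAll keeps its real behavior.
        doAnswer(inv -> { job.runs++; return null; }).when(job).runOnce();

        job.runAll(3);
        System.out.println(job.runs); // 3: the real runAll drove the stubbed runOnce
    }
}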
@@ -17,6 +17,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

+import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.experience.ExperienceHandler;
 import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
@@ -26,11 +27,21 @@ import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.deeplearning4j.rl4j.support.*;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
+import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.api.rng.Random;
@@ -40,150 +51,146 @@ import java.util.ArrayList;
 import java.util.List;

 import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class QLearningDiscreteTest {

+    QLearningDiscrete<Encodable> qLearningDiscrete;
+
+    @Mock
+    IHistoryProcessor mockHistoryProcessor;
+
+    @Mock
+    IHistoryProcessor.Configuration mockHistoryConfiguration;
+
+    @Mock
+    MDP<Encodable, Integer, DiscreteSpace> mockMDP;
+
+    @Mock
+    DiscreteSpace mockActionSpace;
+
+    @Mock
+    ObservationSpace<Encodable> mockObservationSpace;
+
+    @Mock
+    IDQN mockDQN;
+
+    @Mock
+    QLearningConfiguration mockQlearningConfiguration;
+
+    int[] observationShape = new int[]{3, 10, 10};
+    int totalObservationSize = 1;
+
+    private void setupMDPMocks() {
+
+        when(mockObservationSpace.getShape()).thenReturn(observationShape);
+
+        when(mockMDP.getObservationSpace()).thenReturn(mockObservationSpace);
+        when(mockMDP.getActionSpace()).thenReturn(mockActionSpace);
+
+        int dataLength = 1;
+        for (int d : observationShape) {
+            dataLength *= d;
+        }
+    }
+
+    private void mockTestContext(int maxSteps, int updateStart, int batchSize, double rewardFactor, int maxExperienceReplay) {
+        when(mockQlearningConfiguration.getBatchSize()).thenReturn(batchSize);
+        when(mockQlearningConfiguration.getRewardFactor()).thenReturn(rewardFactor);
+        when(mockQlearningConfiguration.getExpRepMaxSize()).thenReturn(maxExperienceReplay);
+        when(mockQlearningConfiguration.getSeed()).thenReturn(123L);
+
+        qLearningDiscrete = mock(
+                QLearningDiscrete.class,
+                Mockito.withSettings()
+                        .useConstructor(mockMDP, mockDQN, mockQlearningConfiguration, 0)
+                        .defaultAnswer(Mockito.CALLS_REAL_METHODS)
+        );
+    }
+
+    private void mockHistoryProcessor(int skipFrames) {
+        when(mockHistoryConfiguration.getRescaledHeight()).thenReturn(observationShape[1]);
+        when(mockHistoryConfiguration.getRescaledWidth()).thenReturn(observationShape[2]);
+
+        when(mockHistoryConfiguration.getOffsetX()).thenReturn(0);
+        when(mockHistoryConfiguration.getOffsetY()).thenReturn(0);
+
+        when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
+        when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
+        when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
+        when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
+
+        qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
+    }
+
+    @Before
+    public void setup() {
+        setupMDPMocks();
+
+        for (int i : observationShape) {
+            totalObservationSize *= i;
+        }
+    }

     @Test
-    public void refac_QLearningDiscrete_trainStep() {
+    public void when_singleTrainStep_expect_correctValues() {

         // Arrange
-        MockObservationSpace observationSpace = new MockObservationSpace();
-        MockDQN dqn = new MockDQN();
-        MockRandom random = new MockRandom(new double[]{
-            0.7309677600860596,
-            0.8314409852027893,
-            0.2405363917350769,
-            0.6063451766967773,
-            0.6374173760414124,
-            0.3090505599975586,
-            0.5504369735717773,
-            0.11700659990310669
-        },
-        new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
-        MockMDP mdp = new MockMDP(observationSpace, random);
+        mockTestContext(100,0,2,1.0, 10);

-        int initStepCount = 8;
-
-        QLearningConfiguration conf = QLearningConfiguration.builder()
-            .seed(0L)
-            .maxEpochStep(24)
-            .maxStep(0)
-            .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
-            .updateStart(initStepCount)
-            .rewardFactor(1.0)
-            .gamma(0)
-            .errorClamp(0)
-            .minEpsilon(0)
-            .epsilonNbStep(0)
-            .doubleDQN(true)
-            .build();
-
-        MockDataManager dataManager = new MockDataManager(false);
-        MockExperienceHandler experienceHandler = new MockExperienceHandler();
-        TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, experienceHandler, 10, random);
-        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
-        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
-        sut.setHistoryProcessor(hp);
-        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
-        List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
+        // An example observation and 2 Q values output (2 actions)
+        Observation observation = new Observation(Nd4j.zeros(observationShape));
+        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
+
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));

         // Act
-        IDataManager.StatEntry result = sut.trainEpoch();
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);

         // Assert
-        // HistoryProcessor calls
-        double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
-        assertEquals(expectedRecords.length, hp.recordCalls.size());
-        for (int i = 0; i < expectedRecords.length; ++i) {
-            assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
-        }
-        assertEquals(0, hp.startMonitorCallCount);
-        assertEquals(0, hp.stopMonitorCallCount);
-
-        // DQN calls
-        assertEquals(1, dqn.fitParams.size());
-        assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
-        assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
-        assertEquals(14, dqn.outputParams.size());
-        double[][] expectedDQNOutput = new double[][]{
-            new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
-            new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
-            new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
-            new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
-            new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
-            new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
-            new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
-            new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
-            new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
-            new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
-            new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
-            new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
-            new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
-            new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
-        };
-        for (int i = 0; i < expectedDQNOutput.length; ++i) {
-            INDArray outputParam = dqn.outputParams.get(i);
-
-            assertEquals(5, outputParam.shape()[1]);
-            assertEquals(1, outputParam.shape()[2]);
-
-            double[] expectedRow = expectedDQNOutput[i];
-            for (int j = 0; j < expectedRow.length; ++j) {
-                assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
-            }
-        }
-
-        // MDP calls
-        assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
-
-        // ExperienceHandler calls
-        double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
-        int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
-        double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
-        double[][] expectedTrObservations = new double[][] {
-            new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
-            new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
-            new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
-            new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
-            new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
-            new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
-            new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
-            new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
-        };
-
-        assertEquals(expectedTrObservations.length, experienceHandler.addExperienceArgs.size());
-        for(int i = 0; i < expectedTrRewards.length; ++i) {
-            StateActionPair<Integer> stateActionPair = experienceHandler.addExperienceArgs.get(i);
-            assertEquals(expectedTrRewards[i], stateActionPair.getReward(), 0.0001);
-            assertEquals((int)expectedTrActions[i], (int)stateActionPair.getAction());
-            for(int j = 0; j < expectedTrObservations[i].length; ++j) {
-                assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * stateActionPair.getObservation().getData().getDouble(0, j, 0), 0.0001);
-            }
-        }
-        assertEquals(expectedTrNextObservation[expectedTrNextObservation.length - 1], 255.0 * experienceHandler.finalObservation.getData().getDouble(0), 0.0001);
-
-        // trainEpoch result
-        assertEquals(initStepCount + 16, result.getStepCounter());
-        assertEquals(300.0, result.getReward(), 0.00001);
-        assertTrue(dqn.hasBeenReset);
-        assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
+        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+
+        StepReply<Observation> stepReply = stepReturn.getStepReply();
+
+        assertEquals(0, stepReply.getReward(), 1e-5);
+        assertFalse(stepReply.isDone());
+        assertFalse(stepReply.getObservation().isSkipped());
+        assertEquals(observation.getData().reshape(observationShape), stepReply.getObservation().getData().reshape(observationShape));
     }

-    public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
-        public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
-                        QLearningConfiguration conf, IDataManager dataManager, ExperienceHandler<Integer, Transition<Integer>> experienceHandler,
-                        int epsilonNbStep, Random rnd) {
-            super(mdp, dqn, conf, epsilonNbStep, rnd);
-            addListener(new DataManagerTrainingListener(dataManager));
-            setExperienceHandler(experienceHandler);
-        }
-
-        @Override
-        protected DataSet setTarget(List<Transition<Integer>> transitions) {
-            return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
-        }
-
-        @Override
-        public IDataManager.StatEntry trainEpoch() {
-            return super.trainEpoch();
-        }
+    @Test
+    public void when_singleTrainStepSkippedFrames_expect_correctValues() {
+        // Arrange
+        mockTestContext(100,0,2,1.0, 10);
+
+        mockHistoryProcessor(2);
+
+        // An example observation and 2 Q values output (2 actions)
+        Observation observation = new Observation(Nd4j.zeros(observationShape));
+        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
+
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+
+        // Act
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
+
+        // Assert
+        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+
+        StepReply<Observation> stepReply = stepReturn.getStepReply();
+
+        assertEquals(0, stepReply.getReward(), 1e-5);
+        assertFalse(stepReply.isDone());
+        assertTrue(stepReply.getObservation().isSkipped());
     }
+
+    //TODO: there are much more test cases here that can be improved upon
+
 }
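In the tests above, the environment is simulated entirely through stubs: the MDP mock returns a fixed observation space and a canned StepReply for any action. Below is an isolated sketch of that arrangement using the RL4J types from this diff; the shape and reward values mirror the test but are otherwise arbitrary, and Box is assumed (as in the test) to be an Encodable observation type.

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;

import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

// Sketch: a fully stubbed MDP for unit tests, in the style of the diff above.
public class MdpStubSketch {

    @SuppressWarnings("unchecked")
    static MDP<Encodable, Integer, DiscreteSpace> stubbedMdp() {
        MDP<Encodable, Integer, DiscreteSpace> mdp = mock(MDP.class);
        ObservationSpace<Encodable> observationSpace = mock(ObservationSpace.class);

        when(observationSpace.getShape()).thenReturn(new int[]{3, 10, 10});
        when(mdp.getObservationSpace()).thenReturn(observationSpace);
        // Every step yields a zero-reward, non-terminal transition (3*10*10 = 300 values).
        when(mdp.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[300]), 0, false, null));
        return mdp;
    }
}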
@@ -1,79 +0,0 @@
-package org.deeplearning4j.rl4j.learning.sync.support;
-
-import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.mdp.MDP;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.ObservationSpace;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-public class MockMDP implements MDP<Object, Integer, DiscreteSpace> {
-
-    private final int maxSteps;
-    private final DiscreteSpace actionSpace = new DiscreteSpace(1);
-    private final MockObservationSpace observationSpace = new MockObservationSpace();
-
-    private int currentStep = 0;
-
-    public MockMDP(int maxSteps) {
-
-        this.maxSteps = maxSteps;
-    }
-
-    @Override
-    public ObservationSpace<Object> getObservationSpace() {
-        return observationSpace;
-    }
-
-    @Override
-    public DiscreteSpace getActionSpace() {
-        return actionSpace;
-    }
-
-    @Override
-    public Object reset() {
-        return null;
-    }
-
-    @Override
-    public void close() {
-
-    }
-
-    @Override
-    public StepReply<Object> step(Integer integer) {
-        return new StepReply<Object>(null, 1.0, isDone(), null);
-    }
-
-    @Override
-    public boolean isDone() {
-        return currentStep >= maxSteps;
-    }
-
-    @Override
-    public MDP<Object, Integer, DiscreteSpace> newInstance() {
-        return null;
-    }
-
-    private static class MockObservationSpace implements ObservationSpace {
-
-        @Override
-        public String getName() {
-            return null;
-        }
-
-        @Override
-        public int[] getShape() {
-            return new int[0];
-        }
-
-        @Override
-        public INDArray getLow() {
-            return null;
-        }
-
-        @Override
-        public INDArray getHigh() {
-            return null;
-        }
-    }
-}
@@ -257,9 +257,9 @@ public class PolicyTest {
         }

         @Override
-        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
+        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
             mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
-            return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
+            return super.refacInitMdp(mdpWrapper, hp);
         }
     }
 }
@@ -1,37 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.support;
-
-import lombok.AllArgsConstructor;
-import lombok.Value;
-import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
-
-@Value
-@AllArgsConstructor
-public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
-
-    private Long seed;
-    private int maxEpochStep;
-    private int maxStep;
-    private int updateStart;
-    private double rewardFactor;
-    private double gamma;
-    private double errorClamp;
-    private int numThreads;
-    private int nStep;
-    private int learnerUpdateFrequency;
-}
@@ -1,75 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import lombok.Getter;
-import lombok.Setter;
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
-import org.deeplearning4j.rl4j.network.NeuralNet;
-
-import java.util.concurrent.atomic.AtomicInteger;
-
-public class MockAsyncGlobal<NN extends NeuralNet> implements IAsyncGlobal<NN> {
-
-    @Getter
-    private final NN current;
-
-    public boolean hasBeenStarted = false;
-    public boolean hasBeenTerminated = false;
-
-    public int enqueueCallCount = 0;
-
-    @Setter
-    private int maxLoops;
-    @Setter
-    private int numLoopsStopRunning;
-    private int currentLoop = 0;
-
-    public MockAsyncGlobal() {
-        this(null);
-    }
-
-    public MockAsyncGlobal(NN current) {
-        maxLoops = Integer.MAX_VALUE;
-        numLoopsStopRunning = Integer.MAX_VALUE;
-        this.current = current;
-    }
-
-    @Override
-    public boolean isRunning() {
-        return currentLoop < numLoopsStopRunning;
-    }
-
-    @Override
-    public void terminate() {
-        hasBeenTerminated = true;
-    }
-
-    @Override
-    public boolean isTrainingComplete() {
-        return currentLoop >= maxLoops;
-    }
-
-    @Override
-    public void start() {
-        hasBeenStarted = true;
-    }
-
-    @Override
-    public AtomicInteger getT() {
-        return null;
-    }
-
-    @Override
-    public NN getTarget() {
-        return current;
-    }
-
-    @Override
-    public void enqueue(Gradient[] gradient, Integer nstep) {
-        ++enqueueCallCount;
-    }
-
-    public void increaseCurrentLoop() {
-        ++currentLoop;
-    }
-}
@@ -1,46 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.observation.Observation;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class MockExperienceHandler implements ExperienceHandler<Integer, Transition<Integer>> {
-    public List<StateActionPair<Integer>> addExperienceArgs = new ArrayList<StateActionPair<Integer>>();
-    public Observation finalObservation;
-    public boolean isGenerateTrainingBatchCalled;
-    public boolean isResetCalled;
-
-    @Override
-    public void addExperience(Observation observation, Integer action, double reward, boolean isTerminal) {
-        addExperienceArgs.add(new StateActionPair<>(observation, action, reward, isTerminal));
-    }
-
-    @Override
-    public void setFinalObservation(Observation observation) {
-        finalObservation = observation;
-    }
-
-    @Override
-    public List<Transition<Integer>> generateTrainingBatch() {
-        isGenerateTrainingBatchCalled = true;
-        return new ArrayList<Transition<Integer>>() {
-            {
-                add(new Transition<Integer>(null, 0, 0.0, false));
-            }
-        };
-    }
-
-    @Override
-    public void reset() {
-        isResetCalled = true;
-    }
-
-    @Override
-    public int getTrainingBatchSize() {
-        return 1;
-    }
-}
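The deleted MockExperienceHandler above mostly recorded its calls so tests could assert on them; a Mockito mock provides the same bookkeeping without a support class, which is presumably why these hand-rolled mocks were dropped in this commit. A sketch under that assumption (the surrounding class and main method are invented for illustration):

import org.deeplearning4j.rl4j.experience.ExperienceHandler;
import org.deeplearning4j.rl4j.learning.sync.Transition;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

// Sketch: the call-recording MockExperienceHandler did by hand, expressed with Mockito.
public class ExperienceHandlerMockSketch {

    @SuppressWarnings("unchecked")
    public static void main(String[] args) {
        ExperienceHandler<Integer, Transition<Integer>> handler = mock(ExperienceHandler.class);

        handler.addExperience(null, 0, 0.0, false); // code under test would call this
        handler.reset();

        verify(handler, times(1)).addExperience(null, 0, 0.0, false); // replaces addExperienceArgs
        verify(handler, times(1)).reset();                            // replaces the isResetCalled flag
    }
}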
@@ -1,77 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import lombok.Setter;
-import org.deeplearning4j.rl4j.learning.IEpochTrainer;
-import org.deeplearning4j.rl4j.learning.ILearning;
-import org.deeplearning4j.rl4j.learning.listener.*;
-import org.deeplearning4j.rl4j.util.IDataManager;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class MockTrainingListener implements TrainingListener {
-
-    private final MockAsyncGlobal asyncGlobal;
-    public int onTrainingStartCallCount = 0;
-    public int onTrainingEndCallCount = 0;
-    public int onNewEpochCallCount = 0;
-    public int onEpochTrainingResultCallCount = 0;
-    public int onTrainingProgressCallCount = 0;
-
-    @Setter
-    private int remainingTrainingStartCallCount = Integer.MAX_VALUE;
-    @Setter
-    private int remainingOnNewEpochCallCount = Integer.MAX_VALUE;
-    @Setter
-    private int remainingOnEpochTrainingResult = Integer.MAX_VALUE;
-    @Setter
-    private int remainingonTrainingProgressCallCount = Integer.MAX_VALUE;
-
-    public final List<IDataManager.StatEntry> statEntries = new ArrayList<>();
-
-    public MockTrainingListener() {
-        this(null);
-    }
-
-    public MockTrainingListener(MockAsyncGlobal asyncGlobal) {
-        this.asyncGlobal = asyncGlobal;
-    }
-
-
-    @Override
-    public ListenerResponse onTrainingStart() {
-        ++onTrainingStartCallCount;
-        --remainingTrainingStartCallCount;
-        return remainingTrainingStartCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
-    }
-
-    @Override
-    public ListenerResponse onNewEpoch(IEpochTrainer trainer) {
-        ++onNewEpochCallCount;
-        --remainingOnNewEpochCallCount;
-        return remainingOnNewEpochCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
-    }
-
-    @Override
-    public ListenerResponse onEpochTrainingResult(IEpochTrainer trainer, IDataManager.StatEntry statEntry) {
-        ++onEpochTrainingResultCallCount;
-        --remainingOnEpochTrainingResult;
-        statEntries.add(statEntry);
-        return remainingOnEpochTrainingResult < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
-    }
-
-    @Override
-    public ListenerResponse onTrainingProgress(ILearning learning) {
-        ++onTrainingProgressCallCount;
-        --remainingonTrainingProgressCallCount;
-        if(asyncGlobal != null) {
-            asyncGlobal.increaseCurrentLoop();
-        }
-        return remainingonTrainingProgressCallCount < 0 ? ListenerResponse.STOP : ListenerResponse.CONTINUE;
-    }
-
-    @Override
-    public void onTrainingEnd() {
-        ++onTrainingEndCallCount;
-    }
-}
@@ -1,19 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import org.deeplearning4j.nn.gradient.Gradient;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
-import org.deeplearning4j.rl4j.learning.async.UpdateAlgorithm;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class MockUpdateAlgorithm implements UpdateAlgorithm<MockNeuralNet> {
-
-    public final List<List<StateActionPair<Integer>>> experiences = new ArrayList<List<StateActionPair<Integer>>>();
-
-    @Override
-    public Gradient[] computeGradients(MockNeuralNet current, List<StateActionPair<Integer>> experience) {
-        experiences.add(experience);
-        return new Gradient[0];
-    }
-}
@@ -151,20 +151,26 @@ public class DataManagerTrainingListenerTest {
     private static class TestTrainer implements IEpochTrainer, ILearning
     {
         @Override
-        public int getStepCounter() {
+        public int getStepCount() {
            return 0;
         }

         @Override
-        public int getEpochCounter() {
+        public int getEpochCount() {
            return 0;
         }

         @Override
-        public int getCurrentEpochStep() {
+        public int getEpisodeCount() {
            return 0;
         }

+        @Override
+        public int getCurrentEpisodeStepCount() {
+            return 0;
+        }
+
        @Getter
        @Setter
        private IHistoryProcessor historyProcessor;