RL4J: Add Backwardly Compatible Builder patterns (#326)

* Starting to switch configs of RL algorithms to use more fluent builder patterns. Many parameter choices in different algorithms default to SOTA and only be changed in specific cases

Signed-off-by: Bam4d <chris.bam4d@gmail.com>

* remove personal gpu-build file

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* refactored out configurations so they are heirarchical and re-usable, this is a step towards having a plug-and-play framework for different algorithms

* backwardly compatible configurations

* adding documentation to new configuration classes

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* private access modifiers are better suited here

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* RL4j does not compile without java 8 due to previous updates

fixing null pointers when listener arrays are empty

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* fixing copyright headers

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* uncomment logging line

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* fixing default value for learningUpdateFrequency

fixing test failure due to #352

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Bam4d <chris.bam4d@gmail.com>
master
Chris Bamford 2020-04-06 04:36:12 +01:00 committed by GitHub
parent fb1c41c512
commit 1a35ebec2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1388 additions and 495 deletions

View File

@ -16,6 +16,7 @@
package org.datavec.api.transform.split;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@ -20,6 +20,18 @@
xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
</plugin>
</plugins>
</build>
<parent>
<artifactId>rl4j</artifactId>

View File

@ -1,3 +1,19 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning;
public interface EpochStepCounter {

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,10 +17,10 @@
package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
@ -34,21 +35,12 @@ public interface ILearning<O, A, AS extends ActionSpace<A>> {
int getStepCounter();
LConfiguration getConfiguration();
ILearningConfiguration getConfiguration();
MDP<O, A, AS> getMdp();
IHistoryProcessor getHistoryProcessor();
interface LConfiguration {
Integer getSeed();
int getMaxEpochStep();
int getMaxStep();
double getGamma();
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.primitives.Pair;
@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*
* <p>
* In the original paper, the authors uses Asynchronous
* Gradient Descent: Hogwild! It is a way to apply gradients
* and modify a model in a lock-free manner.
*
* <p>
* As a way to implement this with dl4j, it is unfortunately
* necessary at the time of writing to apply the gradient
* (update the parameters) on a single separate global thread.
*
* <p>
* This Central thread for Asynchronous Method of reinforcement learning
* enqueue the gradients coming from the different threads and update its
* model and target. Those neurals nets are then synced by the other threads.
*
* <p>
* The benefits of this thread is that the updater is "shared" between all thread
* we have a single updater which is the single updater of the model contained here
*
* <p>
* This is similar to RMSProp with shared g and momentum
*
* <p>
* When Hogwild! is implemented, this could be replaced by a simple data
* structure
*
*
*/
@Slf4j
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
@ -56,7 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter
final private NN current;
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
final private AsyncConfiguration a3cc;
final private IAsyncLearningConfiguration configuration;
private final IAsyncLearning learning;
@Getter
private AtomicInteger T = new AtomicInteger(0);
@ -65,20 +66,20 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter
private boolean running = true;
public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
this.current = initial;
target = (NN) initial.clone();
this.a3cc = a3cc;
this.configuration = configuration;
this.learning = learning;
queue = new ConcurrentLinkedQueue<>();
}
public boolean isTrainingComplete() {
return T.get() >= a3cc.getMaxStep();
return T.get() >= configuration.getMaxStep();
}
public void enqueue(Gradient[] gradient, Integer nstep) {
if(running && !isTrainingComplete()) {
if (running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep));
}
}
@ -94,9 +95,8 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
synchronized (this) {
current.applyGradient(gradient, pair.getSecond());
}
if (a3cc.getTargetDqnUpdateFreq() != -1
&& T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond())
/ a3cc.getTargetDqnUpdateFreq()) {
if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
/ configuration.getLearnerUpdateFrequency()) {
log.info("TARGET UPDATE at T = " + T.get());
synchronized (this) {
target.copy(current);
@ -111,7 +111,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
*/
public void terminate() {
if(running) {
if (running) {
running = false;
queue.clear();
learning.terminate();

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -21,14 +22,17 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
/**
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
* The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads()
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
* (see setProgressEventInterval(int))
*
@ -37,8 +41,8 @@ import org.nd4j.linalg.factory.Nd4j;
*/
@Slf4j
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Learning<O, A, AS, NN>
implements IAsyncLearning {
extends Learning<O, A, AS, NN>
implements IAsyncLearning {
private Thread monitorThread = null;
@ -56,9 +60,10 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
/**
* Returns the configuration
* @return the configuration (see {@link AsyncConfiguration})
*
* @return the configuration (see {@link AsyncQLearningConfiguration})
*/
public abstract AsyncConfiguration getConfiguration();
public abstract IAsyncLearningConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i, int deviceAffinity);
@ -77,12 +82,13 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
/**
* Number of milliseconds between calls to onTrainingProgress
*/
@Getter @Setter
@Getter
@Setter
private int progressMonitorFrequency = 20000;
private void launchThreads() {
startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start();
}
@ -132,7 +138,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
monitorThread = Thread.currentThread();
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
canContinue = listeners.notifyTrainingProgress(this);
if(!canContinue) {
if (!canContinue) {
return;
}
@ -155,11 +161,11 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
*/
public void terminate() {
if(canContinue) {
if (canContinue) {
canContinue = false;
Thread safeMonitorThread = monitorThread;
if(safeMonitorThread != null) {
if (safeMonitorThread != null) {
safeMonitorThread.interrupt();
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -23,6 +24,7 @@ import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -155,7 +157,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
}
private void handleTraining(RunContext context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
@ -197,7 +199,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract AsyncConfiguration getConf();
protected abstract IAsyncLearningConfiguration getConf();
protected abstract IPolicy<O, A> getPolicy(NN net);

View File

@ -1,5 +1,7 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
if (getConf().getTargetDqnUpdateFreq() == -1)
if (getConf().getLearnerUpdateFrequency() == -1)
output = current.outputAll(obs.getData());
else synchronized (getAsyncGlobal()) {
output = getAsyncGlobal().getTarget().outputAll(obs.getData());

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,11 +17,15 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.*;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
* Training for A3C in the Discrete Domain
*
* <p>
* All methods are fully implemented as described in the
* https://arxiv.org/abs/1602.01783 paper.
*
*/
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
@Getter
final public A3CConfiguration configuration;
final public A3CLearningConfiguration configuration;
@Getter
final protected MDP<O, Integer, DiscreteSpace> mdp;
final private IActorCritic iActorCritic;
@ -49,15 +53,15 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@Getter
final private ACPolicy<O> policy;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
Integer seed = conf.getSeed();
Long seed = conf.getSeed();
Random rnd = Nd4j.getRandom();
if(seed != null) {
if (seed != null) {
rnd.setSeed(seed);
}
@ -65,7 +69,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
}
protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, this.getConfiguration(), deviceNum, getListeners(), i);
}
public IActorCritic getNeuralNet() {
@ -76,9 +80,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@AllArgsConstructor
@Builder
@EqualsAndHashCode(callSuper = false)
public static class A3CConfiguration implements AsyncConfiguration {
public static class A3CConfiguration {
Integer seed;
int seed;
int maxEpochStep;
int maxStep;
int numThread;
@ -88,8 +92,20 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
double gamma;
double errorClamp;
public int getTargetDqnUpdateFreq() {
return -1;
/**
* Converts the deprecated A3CConfiguration to the new LearningConfiguration format
*/
public A3CLearningConfiguration toLearningConfiguration() {
return A3CLearningConfiguration.builder()
.seed(new Long(seed))
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.numThreads(numThread)
.nStep(nstep)
.rewardFactor(rewardFactor)
.gamma(gamma)
.build();
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -18,10 +19,12 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
@ -29,16 +32,15 @@ import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
*
* <p>
* Training for A3C in the Discrete Domain
*
* <p>
* Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety
*
* <p>
* It uses CompGraph because there is benefit to combine the
* first layers since they're essentially doing the same dimension
* reduction task
*
**/
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
@ -46,12 +48,22 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
super(mdp, IActorCritic, conf.toLearningConfiguration());
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
super(mdp, IActorCritic, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
@ -59,21 +71,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
@ -25,67 +28,81 @@ import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
*
* <p>
* Training for A3C in the Discrete Domain
*
* <p>
* We use specifically the Separate version because
* the model is too small to have enough benefit by sharing layers
*
*/
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
IDataManager dataManager) {
IDataManager dataManager) {
this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
super(mdp, actorCritic, conf.toLearningConfiguration());
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
super(mdp, actorCritic, conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf, IDataManager dataManager) {
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) {
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.api.rng.Random;
import java.util.Stack;
@ -45,7 +46,7 @@ import java.util.Stack;
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
@Getter
final protected A3CDiscrete.A3CConfiguration conf;
final protected A3CLearningConfiguration conf;
@Getter
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
@Getter
@ -54,14 +55,14 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
final private Random rnd;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
Integer seed = conf.getSeed();
Long seed = conf.getSeed();
rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed + threadNumber);

View File

@ -1,49 +1,53 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.*;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
@Getter
final public AsyncNStepQLConfiguration configuration;
final public AsyncQLearningConfiguration configuration;
@Getter
final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter
final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
return new DQNPolicy<O>(getNeuralNet());
}
@Data
@AllArgsConstructor
@Builder
@EqualsAndHashCode(callSuper = false)
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
public static class AsyncNStepQLConfiguration {
Integer seed;
int maxEpochStep;
@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
float minEpsilon;
int epsilonNbStep;
public AsyncQLearningConfiguration toLearningConfiguration() {
return AsyncQLearningConfiguration.builder()
.seed(new Long(seed))
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.numThreads(numThread)
.nStep(nstep)
.targetDqnUpdateFreq(targetDqnUpdateFreq)
.updateStart(updateStart)
.rewardFactor(rewardFactor)
.gamma(gamma)
.errorClamp(errorClamp)
.minEpsilon(minEpsilon)
.build();
}
}
}

View File

@ -1,24 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}

View File

@ -1,22 +1,26 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* SPDX-License-Identifier: Apache-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -32,35 +36,56 @@ public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends Async
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf);
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf.toLearningConfiguration());
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration());
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
dataManager);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
}

View File

@ -1,17 +1,18 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -42,7 +44,7 @@ import java.util.Stack;
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
@Getter
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
final protected AsyncQLearningConfiguration conf;
@Getter
final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter
@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
final private Random rnd;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
AsyncQLearningConfiguration conf,
TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf;
@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
this.threadNumber = threadNumber;
rnd = Nd4j.getRandom();
Integer seed = conf.getSeed();
Long seed = conf.getSeed();
if(seed != null) {
rnd.setSeed(seed + threadNumber);
}

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration {
/**
* The number of asynchronous threads to use to generate gradients
*/
private final int numThreads;
/**
* The number of steps to calculate gradients over
*/
private final int nStep;
/**
* The frequency of async training iterations to update the target network.
*
* If this is set to -1 then the target network is updated after every training iteration
*/
@Builder.Default
private int learnerUpdateFrequency = -1;
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration {
/**
* The number of asynchronous threads to use to generate experience data
*/
private final int numThreads;
/**
* The number of steps in each training interations
*/
private final int nStep;
public int getLearnerUpdateFrequency() {
return getTargetDqnUpdateFreq();
}
}

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
public interface IAsyncLearningConfiguration extends ILearningConfiguration {
int getNumThreads();
int getNStep();
int getLearnerUpdateFrequency();
int getMaxStep();
}

View File

@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -14,36 +14,16 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
package org.deeplearning4j.rl4j.learning.configuration;
import org.deeplearning4j.rl4j.learning.ILearning;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16.
*
* Interface configuration for all training method that inherit
* from AsyncLearning
*/
public interface AsyncConfiguration extends ILearning.LConfiguration {
Integer getSeed();
public interface ILearningConfiguration {
Long getSeed();
int getMaxEpochStep();
int getMaxStep();
int getNumThread();
int getNstep();
int getTargetDqnUpdateFreq();
int getUpdateStart();
double getRewardFactor();
double getGamma();
double getErrorClamp();
double getRewardFactor();
}

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
public class LearningConfiguration implements ILearningConfiguration {
/**
* Seed value used for training
*/
@Builder.Default
private Long seed = System.currentTimeMillis();
/**
* The maximum number of steps in each episode
*/
@Builder.Default
private int maxEpochStep = 200;
/**
* The maximum number of steps to train for
*/
@Builder.Default
private int maxStep = 150000;
/**
* Gamma parameter used for discounted rewards
*/
@Builder.Default
private double gamma = 0.99;
/**
* Scaling parameter for rewards
*/
@Builder.Default
private double rewardFactor = 1.0;
}

View File

@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class QLearningConfiguration extends LearningConfiguration {
/**
* The maximum size of the experience replay buffer
*/
@Builder.Default
private int expRepMaxSize = 150000;
/**
* The batch size of experience for each training iteration
*/
@Builder.Default
private int batchSize = 32;
/**
* How many steps between target network updates
*/
@Builder.Default
private int targetDqnUpdateFreq = 100;
/**
* The number of steps to initially wait for until samplling batches from experience replay buffer
*/
@Builder.Default
private int updateStart = 10;
/**
* Prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will result in no clamping
*/
@Builder.Default
private double errorClamp = 1.0;
/**
* The minimum probability for random exploration action during episilon-greedy annealing
*/
@Builder.Default
private double minEpsilon = 0.1f;
/**
* The number of steps to anneal epsilon to its minimum value.
*/
@Builder.Default
private int epsilonNbStep = 10000;
/**
* Whether to use the double DQN algorithm
*/
@Builder.Default
private boolean doubleDQN = false;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -63,7 +64,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
/**
* This method will train the model<p>
* The training stop when:<br>
* - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
* - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
* OR<br>
* - a listener explicitly stops it<br>
* <p>

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.*;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
@ -59,15 +69,15 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
public QLearning(QLConfiguration conf) {
public QLearning(QLearningConfiguration conf) {
this(conf, getSeededRandom(conf.getSeed()));
}
public QLearning(QLConfiguration conf, Random random) {
public QLearning(QLearningConfiguration conf, Random random) {
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
private static Random getSeededRandom(Integer seed) {
private static Random getSeededRandom(Long seed) {
Random rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed);
@ -95,7 +105,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
return getQNetwork();
}
public abstract QLConfiguration getConfiguration();
public abstract QLearningConfiguration getConfiguration();
protected abstract void preEpoch();
@ -198,7 +208,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
double reward;
int episodeLength;
List<Double> scores;
float epsilon;
double epsilon;
double startQ;
double meanQ;
}
@ -213,12 +223,14 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
}
@Data
@AllArgsConstructor
@Builder
@Deprecated
@EqualsAndHashCode(callSuper = false)
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
public static class QLConfiguration implements LConfiguration {
public static class QLConfiguration {
Integer seed;
int maxEpochStep;
@ -237,7 +249,25 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
@JsonPOJOBuilder(withPrefix = "")
public static final class QLConfigurationBuilder {
}
public QLearningConfiguration toLearningConfiguration() {
return QLearningConfiguration.builder()
.seed(seed.longValue())
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.expRepMaxSize(expRepMaxSize)
.batchSize(batchSize)
.targetDqnUpdateFreq(targetDqnUpdateFreq)
.updateStart(updateStart)
.rewardFactor(rewardFactor)
.gamma(gamma)
.errorClamp(errorClamp)
.minEpsilon(minEpsilon)
.epsilonNbStep(epsilonNbStep)
.doubleDQN(doubleDQN)
.build();
}
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import lombok.Setter;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
@ -45,16 +47,15 @@ import java.util.ArrayList;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
*
* <p>
* DQN or Deep Q-Learning in the Discrete domain
*
* <p>
* http://arxiv.org/abs/1312.5602
*
*/
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
@Getter
final private QLConfiguration configuration;
final private QLearningConfiguration configuration;
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
@Getter
private DQNPolicy<O> policy;
@ -78,16 +79,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
return mdp;
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
int epsilonNbStep) {
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
}
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this);
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@ -125,6 +125,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
/**
* Single step of training
*
* @param obs last obs
* @return relevant info for next step
*/
@ -135,8 +136,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = getConfiguration().getUpdateStart()
+ ((getConfiguration().getBatchSize() + historyLength) * skipFrame);
int updateStart = this.getConfiguration().getUpdateStart()
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
Double maxQ = Double.NaN; //ignore if Nan for stats
@ -161,7 +162,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (!obs.isSkipped()) {
// Add experience
if(pendingTransition != null) {
if (pendingTransition != null) {
pendingTransition.setNextObservation(obs);
getExpReplay().store(pendingTransition);
}
@ -188,7 +189,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Override
protected void finishEpoch(Observation observation) {
if(pendingTransition != null) {
if (pendingTransition != null) {
pendingTransition.setNextObservation(observation);
getExpReplay().store(pendingTransition);
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -17,7 +18,9 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -36,33 +39,55 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf, IDataManager dataManager) {
QLConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -38,7 +41,13 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep());
}
@ -48,18 +57,33 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -36,7 +37,7 @@ import java.util.Collection;
*
* Standard implementation of ActorCriticCompGraph
*/
public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IActorCritic<NN> {
public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph> {
final protected ComputationGraph cg;
@Getter
@ -73,13 +74,13 @@ public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IA
}
}
public NN clone() {
NN nn = (NN)new ActorCriticCompGraph(cg.clone());
public ActorCriticCompGraph clone() {
ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone());
nn.cg.setListeners(cg.getListeners());
return nn;
}
public void copy(NN from) {
public void copy(ActorCriticCompGraph from) {
cg.setParams(from.cg.params());
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
*
@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph {
Configuration conf;
ActorCriticNetworkConfiguration conf;
public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) {
@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom
return new ActorCriticCompGraph(model);
}
@AllArgsConstructor
@Builder
@Value
@Deprecated
public static class Configuration {
double l2;
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public ActorCriticNetworkConfiguration toNetworkConfiguration() {
ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder()
.l2(l2)
.updater(updater)
.useLSTM(useLSTM);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,6 @@
package org.deeplearning4j.rl4j.network.ac;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Value;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph {
Configuration conf;
ActorCriticDenseNetworkConfiguration conf;
public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1;
@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
"input");
for (int i = 1; i < conf.getNumLayer(); i++) {
for (int i = 1; i < conf.getNumLayers(); i++) {
confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build(), (i - 1) + "");
}
if (conf.isUseLSTM()) {
confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH)
.nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + "");
confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH)
.nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nOut(1).build(), getConf().getNumLayer() + "");
.nOut(1).build(), getConf().getNumLayers() + "");
confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
.nOut(numOutputs).build(), getConf().getNumLayer() + "");
.nOut(numOutputs).build(), getConf().getNumLayers() + "");
} else {
confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nOut(1).build(), (getConf().getNumLayer() - 1) + "");
.nOut(1).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
.nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + "");
.nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + "");
}
confB.setOutputs("value", "softmax");
@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
return new ActorCriticCompGraph(model);
}
@AllArgsConstructor
@Builder
@Value
public static class Configuration {
int numLayer;
int numHiddenNodes;
double l2;
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
*
*
*/
@Value
public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
Configuration conf;
ActorCriticDenseNetworkConfiguration conf;
public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1;
@ -53,27 +57,27 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
nIn *= i;
}
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
.l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
.l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
for (int i = 1; i < conf.getNumLayer(); i++) {
for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.activation(Activation.RELU).build());
}
if (conf.isUseLSTM()) {
confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
} else {
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
}
confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
@ -87,28 +91,28 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
}
NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
//.regularization(true)
//.l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
//.regularization(true)
//.l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
for (int i = 1; i < conf.getNumLayer(); i++) {
for (int i = 1; i < conf.getNumLayers(); i++) {
confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.activation(Activation.RELU).build());
}
if (conf.isUseLSTM()) {
confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
} else {
confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
}
confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
@AllArgsConstructor
@Value
@Builder
@Deprecated
public static class Configuration {
int numLayer;
@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() {
ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder()
.numHiddenNodes(numHiddenNodes)
.numLayers(numLayer)
.l2(l2)
.updater(updater)
.useLSTM(useLSTM);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration {
/**
* The number of layers in the dense network
*/
@Builder.Default
private int numLayers = 3;
/**
* The number of hidden neurons in each layer
*/
@Builder.Default
private int numHiddenNodes = 100;
}

View File

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class ActorCriticNetworkConfiguration extends NetworkConfiguration {
/**
* Whether or not to add an LSTM layer to the network.
*/
@Builder.Default
private boolean useLSTM = false;
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DQNDenseNetworkConfiguration extends NetworkConfiguration {
/**
* The number of layers in the dense network
*/
@Builder.Default
private int numLayers = 3;
/**
* The number of hidden neurons in each layer
*/
@Builder.Default
private int numHiddenNodes = 100;
}

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Singular;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.learning.config.IUpdater;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
public class NetworkConfiguration {
/**
* The learning rate of the network
*/
@Builder.Default
private double learningRate = 0.01;
/**
* L2 regularization on the network
*/
@Builder.Default
private double l2 = 0.0;
/**
* The network's gradient update algorithm
*/
private IUpdater updater;
/**
* Training listeners attached to the network
*/
@Singular
private List<TrainingListener> listeners;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/
@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
public class DQNFactoryStdConv implements DQNFactory {
Configuration conf;
NetworkConfiguration conf;
public DQN buildDQN(int shapeInputs[], int numOutputs) {
@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory {
return new DQN(model);
}
@AllArgsConstructor
@Builder
@Value
@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory {
double l2;
IUpdater updater;
TrainingListener[] listeners;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public NetworkConfiguration toNetworkConfiguration() {
NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder()
.learningRate(learningRate)
.l2(l2)
.updater(updater);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/
@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class DQNFactoryStdDense implements DQNFactory {
Configuration conf;
DQNDenseNetworkConfiguration conf;
public DQN buildDQN(int[] numInputs, int numOutputs) {
int nIn = 1;
for (int i : numInputs) {
nIn *= i;
}
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
//.updater(Updater.NESTEROVS).momentum(0.9)
//.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay())
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
.l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER)
.l2(conf.getL2())
.list()
.layer(0,
new DenseLayer.Builder()
.nIn(nIn)
.nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build()
);
for (int i = 1; i < conf.getNumLayer(); i++) {
for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build());
.activation(Activation.RELU).build());
}
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
confB.layer(conf.getNumLayers(),
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes())
.nOut(numOutputs)
.build()
);
MultiLayerConfiguration mlnconf = confB.build();
@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory {
@AllArgsConstructor
@Value
@Builder
@Deprecated
public static class Configuration {
int numLayer;
@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory {
double l2;
IUpdater updater;
TrainingListener[] listeners;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public DQNDenseNetworkConfiguration toNetworkConfiguration() {
DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder()
.numHiddenNodes(numHiddenNodes)
.numLayers(numLayer)
.l2(l2)
.updater(updater);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
final private int updateStart;
final private int epsilonNbStep;
final private Random rnd;
final private float minEpsilon;
final private double minEpsilon;
final private IEpochTrainer learning;
public NeuralNet getNeuralNet() {
@ -55,10 +56,10 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
public A nextAction(INDArray input) {
float ep = getEpsilon();
double ep = getEpsilon();
if (learning.getStepCounter() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCounter());
if (rnd.nextFloat() > ep)
if (rnd.nextDouble() > ep)
return policy.nextAction(input);
else
return mdp.getActionSpace().randomAction();
@ -68,7 +69,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
return this.nextAction(observation.getData());
}
public float getEpsilon() {
return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep));
public double getEpsilon() {
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
}
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -22,17 +23,30 @@ import lombok.Builder;
import lombok.Getter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.primitives.Pair;
import java.io.*;
import java.nio.file.*;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
@ -304,7 +318,7 @@ public class DataManager implements IDataManager {
public static class Info {
String trainingName;
String mdpName;
ILearning.LConfiguration conf;
ILearningConfiguration conf;
int stepCounter;
long millisTime;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,14 +17,11 @@
package org.deeplearning4j.rl4j.learning;
import java.util.Arrays;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
*
@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue;
public class HistoryProcessorTest {
@Test
public void testHistoryProcessor() throws Exception {
public void testHistoryProcessor() {
HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder()
.croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build();
IHistoryProcessor hp = new HistoryProcessor(conf);
@ -43,8 +41,6 @@ public class HistoryProcessorTest {
hp.add(a);
INDArray[] h = hp.getHistory();
assertEquals(4, h.length);
// System.out.println(Arrays.toString(a.shape()));
// System.out.println(Arrays.toString(h[0].shape()));
assertEquals( 1, h[0].shape()[0]);
assertEquals(a.shape()[0], h[0].shape()[1]);
assertEquals(a.shape()[1], h[0].shape()[2]);

View File

@ -1,9 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockPolicy;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@ -68,7 +91,7 @@ public class AsyncLearningTest {
public static class TestContext {
MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0);
MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
@ -82,11 +105,11 @@ public class AsyncLearningTest {
}
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
private final AsyncConfiguration conf;
private final IAsyncLearningConfiguration conf;
private final IAsyncGlobal asyncGlobal;
private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;
@ -98,7 +121,7 @@ public class AsyncLearningTest {
}
@Override
public AsyncConfiguration getConfiguration() {
public IAsyncLearningConfiguration getConfiguration() {
return conf;
}

View File

@ -1,7 +1,25 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest {
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest {
}
@Override
protected AsyncConfiguration getConf() {
protected IAsyncLearningConfiguration getConf() {
return config;
}

View File

@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
@ -16,7 +24,6 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
public class AsyncThreadTest {
@ -126,7 +133,7 @@ public class AsyncThreadTest {
public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
@ -149,11 +156,11 @@ public class AsyncThreadTest {
private final MockAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf;
private final IAsyncLearningConfiguration conf;
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
@ -184,7 +191,7 @@ public class AsyncThreadTest {
}
@Override
protected AsyncConfiguration getConf() {
protected IAsyncLearningConfiguration getConf() {
return conf;
}

View File

@ -1,11 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.support.*;
@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest {
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0);
A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
MockActorCritic actorCriticMock = new MockActorCritic();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(actorCriticMock, minitransList);

View File

@ -1,7 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2020 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest {
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0);
AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
MockDQN dqnMock = new MockDQN();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(dqnMock, minitransList);

View File

@ -1,6 +1,26 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -17,7 +37,7 @@ public class SyncLearningTest {
@Test
public void when_training_expect_listenersToBeCalled() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@ -34,7 +54,7 @@ public class SyncLearningTest {
@Test
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@ -52,7 +72,7 @@ public class SyncLearningTest {
@Test
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@ -70,7 +90,7 @@ public class SyncLearningTest {
@Test
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@ -87,12 +107,12 @@ public class SyncLearningTest {
public static class MockSyncLearning extends SyncLearning {
private final LConfiguration conf;
private final ILearningConfiguration conf;
@Getter
private int currentEpochStep = 0;
public MockSyncLearning(LConfiguration conf) {
public MockSyncLearning(ILearningConfiguration conf) {
this.conf = conf;
}
@ -119,7 +139,7 @@ public class SyncLearningTest {
}
@Override
public LConfiguration getConfiguration() {
public ILearningConfiguration getConfiguration() {
return conf;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -17,36 +18,24 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class QLConfigurationTest {
public class QLearningConfigurationTest {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void serialize() throws Exception {
ObjectMapper mapper = new ObjectMapper();
QLearning.QLConfiguration qlConfiguration =
new QLearning.QLConfiguration(
123, //Random seed
200, //Max step By epoch
8000, //Max step
150000, //Max size of experience replay
32, //size of batches
500, //target update (hard)
10, //num step noop warmup
0.01, //reward scaling
0.99, //gamma
1.0, //td error clipping
0.1f, //min epsilon
10000, //num step for eps greedy anneal
true //double DQN
);
QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder()
.build();
// Should not throw..
String json = mapper.writeValueAsString(qlConfiguration);
QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class);
String json = mapper.writeValueAsString(qLearningConfiguration);
QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class);
}
}

View File

@ -1,6 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
@ -27,7 +45,7 @@ public class QLearningDiscreteTest {
// Arrange
MockObservationSpace observationSpace = new MockObservationSpace();
MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[] {
MockRandom random = new MockRandom(new double[]{
0.7309677600860596,
0.8314409852027893,
0.2405363917350769,
@ -36,14 +54,26 @@ public class QLearningDiscreteTest {
0.3090505599975586,
0.5504369735717773,
0.11700659990310669
},
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
},
new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
MockMDP mdp = new MockMDP(observationSpace, random);
int initStepCount = 8;
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
initStepCount, 1.0, 0, 0, 0, 0, true);
QLearningConfiguration conf = QLearningConfiguration.builder()
.seed(0L)
.maxEpochStep(24)
.maxStep(0)
.expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
.updateStart(initStepCount)
.rewardFactor(1.0)
.gamma(0)
.errorClamp(0)
.minEpsilon(0)
.epsilonNbStep(0)
.doubleDQN(true)
.build();
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
@ -58,9 +88,9 @@ public class QLearningDiscreteTest {
// Assert
// HistoryProcessor calls
double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
assertEquals(expectedRecords.length, hp.recordCalls.size());
for(int i = 0; i < expectedRecords.length; ++i) {
for (int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
@ -72,59 +102,59 @@ public class QLearningDiscreteTest {
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
double[][] expectedDQNOutput = new double[][]{
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
};
for(int i = 0; i < expectedDQNOutput.length; ++i) {
for (int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
assertEquals(5, outputParam.shape()[1]);
assertEquals(1, outputParam.shape()[2]);
double[] expectedRow = expectedDQNOutput[i];
for(int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
for (int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
}
}
// MDP calls
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
// ExpReplay calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
double[][] expectedTrObservations = new double[][] {
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0};
int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4};
double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0};
double[][] expectedTrObservations = new double[][]{
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
};
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
for(int i = 0; i < expectedTrRewards.length; ++i) {
for (int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
for (int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
}
}
@ -132,12 +162,12 @@ public class QLearningDiscreteTest {
assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
}
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager));
@ -146,10 +176,10 @@ public class QLearningDiscreteTest {
@Override
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0}));
}
public void setExpReplay(IExpReplay<Integer> exp){
public void setExpReplay(IExpReplay<Integer> exp) {
this.expReplay = exp;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.ac;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
*
* @author saudet
*/
public class ActorCriticTest {
public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF =
new ActorCriticFactorySeparateStdDense.Configuration(
4, //number of layers
32, //number of hidden nodes
0.001, //l2 regularization
new RmsProp(0.0005), null, false
);
public static ActorCriticDenseNetworkConfiguration NET_CONF =
ActorCriticDenseNetworkConfiguration.builder()
.numLayers(4)
.numHiddenNodes(32)
.l2(0.001)
.updater(new RmsProp(0.0005))
.useLSTM(false)
.build();
public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG =
new ActorCriticFactoryCompGraphStdDense.Configuration(
2, //number of layers
128, //number of hidden nodes
0.00001, //l2 regularization
new RmsProp(0.005), null, true
);
public static ActorCriticDenseNetworkConfiguration NET_CONF_CG =
ActorCriticDenseNetworkConfiguration.builder()
.numLayers(2)
.numHiddenNodes(128)
.l2(0.00001)
.updater(new RmsProp(0.005))
.useLSTM(true)
.build();
@Test
public void testModelLoadSave() throws IOException {
ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5);
ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5);
File fileValue = File.createTempFile("rl4j-value-", ".model");
File filePolicy = File.createTempFile("rl4j-policy-", ".model");
@ -63,7 +66,7 @@ public class ActorCriticTest {
assertEquals(acs.valueNet, acs2.valueNet);
assertEquals(acs.policyNet, acs2.policyNet);
ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43);
ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43);
File file = File.createTempFile("rl4j-cg-", ".model");
accg.save(file.getAbsolutePath());
@ -83,15 +86,15 @@ public class ActorCriticTest {
for (double i = eps; i < n; i++) {
for (double j = eps; j < n; j++) {
INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2});
INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2});
INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
INDArray gradient = loss.computeGradient(labels, output, activation, null);
output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
double score = loss.computeScore(labels, output, activation, null, false);
INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2});
INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2});
double score1 = loss.computeScore(labels, output1, activation, null, false);
INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2});
INDArray output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2});
double score2 = loss.computeScore(labels, output2, activation, null, false);
double gradient1 = (score1 - score) / eps;

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.junit.Test;
import org.nd4j.linalg.learning.config.RmsProp;
@ -25,22 +27,20 @@ import java.io.IOException;
import static org.junit.Assert.assertEquals;
/**
*
* @author saudet
*/
public class DQNTest {
public static DQNFactoryStdDense.Configuration NET_CONF =
new DQNFactoryStdDense.Configuration(
3, //number of layers
16, //number of hidden nodes
0.001, //l2 regularization
new RmsProp(0.0005), null
);
private static DQNDenseNetworkConfiguration NET_CONF =
DQNDenseNetworkConfiguration.builder().numLayers(3)
.numHiddenNodes(16)
.l2(0.001)
.updater(new RmsProp(0.0005))
.build();
@Test
public void testModelLoadSave() throws IOException {
DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13);
DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13);
File file = File.createTempFile("rl4j-dqn-", ".model");
dqn.save(file.getAbsolutePath());

View File

@ -128,7 +128,7 @@ public class TransformProcessTest {
// Assert
assertFalse(result.isSkipped());
assertEquals(1, result.getData().shape().length);
assertEquals(2, result.getData().shape().length);
assertEquals(1, result.getData().shape()[0]);
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.support.MockDQN;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockRandom;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@ -186,8 +187,22 @@ public class PolicyTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, 30, random);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true);
QLearningConfiguration conf = QLearningConfiguration.builder()
.seed(0L)
.maxEpochStep(0)
.maxStep(0)
.expRepMaxSize(5)
.batchSize(1)
.targetDqnUpdateFreq(0)
.updateStart(0)
.rewardFactor(1.0)
.gamma(0)
.errorClamp(0)
.minEpsilon(0)
.epsilonNbStep(0)
.doubleDQN(true)
.build();
MockNeuralNet nnMock = new MockNeuralNet();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());

View File

@ -1,22 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Value;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
@AllArgsConstructor
@Value
public class MockAsyncConfiguration implements AsyncConfiguration {
@AllArgsConstructor
public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
private Integer seed;
private Long seed;
private int maxEpochStep;
private int maxStep;
private int numThread;
private int nstep;
private int targetDqnUpdateFreq;
private int updateStart;
private double rewardFactor;
private double gamma;
private double errorClamp;
private int numThreads;
private int nStep;
private int learnerUpdateFrequency;
}

View File

@ -1,3 +1,20 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
@ -5,6 +22,7 @@ import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest {
}
@Override
public LConfiguration getConfiguration() {
public ILearningConfiguration getConfiguration() {
return null;
}

View File

@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@ -16,19 +17,18 @@
package org.deeplearning4j.malmo;
import java.util.HashMap;
import com.microsoft.msr.malmo.TimestampedStringVector;
import com.microsoft.msr.malmo.WorldState;
import org.json.JSONArray;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import com.microsoft.msr.malmo.TimestampedStringVector;
import com.microsoft.msr.malmo.WorldState;
import java.util.HashMap;
/**
* Observation space that contains a grid of Minecraft blocks
*
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
*/
public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@ -44,10 +44,10 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
/**
* Construct observation space from a array of blocks policy should distinguish between.
*
* @param name Name given to Grid element in mission specification
* @param xSize total x size of grid
* @param ySize total y size of grid
* @param zSize total z size of grid
* @param name Name given to Grid element in mission specification
* @param xSize total x size of grid
* @param ySize total y size of grid
* @param zSize total z size of grid
* @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings!
*/
public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... blocks) {
@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@Override
public int[] getShape() {
return new int[] {totalSize};
return new int[]{totalSize};
}
@Override