RL4J: Add Backwardly Compatible Builder patterns (#326)

* Starting to switch configs of RL algorithms to use more fluent builder patterns. Many parameter choices in different algorithms default to SOTA and only be changed in specific cases

Signed-off-by: Bam4d <chris.bam4d@gmail.com>

* remove personal gpu-build file

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* refactored out configurations so they are heirarchical and re-usable, this is a step towards having a plug-and-play framework for different algorithms

* backwardly compatible configurations

* adding documentation to new configuration classes

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* private access modifiers are better suited here

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* RL4j does not compile without java 8 due to previous updates

fixing null pointers when listener arrays are empty

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* fixing copyright headers

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* uncomment logging line

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

* fixing default value for learningUpdateFrequency

fixing test failure due to #352

Signed-off-by: Bam4d <chrisbam4d@gmail.com>

Co-authored-by: Bam4d <chris.bam4d@gmail.com>
master
Chris Bamford 2020-04-06 04:36:12 +01:00 committed by GitHub
parent fb1c41c512
commit 1a35ebec2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 1388 additions and 495 deletions

View File

@ -16,6 +16,7 @@
package org.datavec.api.transform.split; package org.datavec.api.transform.split;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;

View File

@ -20,6 +20,18 @@
xmlns="http://maven.apache.org/POM/4.0.0" xmlns="http://maven.apache.org/POM/4.0.0"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
</plugin>
</plugins>
</build>
<parent> <parent>
<artifactId>rl4j</artifactId> <artifactId>rl4j</artifactId>

View File

@ -1,3 +1,19 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning; package org.deeplearning4j.rl4j.learning;
public interface EpochStepCounter { public interface EpochStepCounter {

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,10 +17,10 @@
package org.deeplearning4j.rl4j.learning; package org.deeplearning4j.rl4j.learning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
@ -34,21 +35,12 @@ public interface ILearning<O, A, AS extends ActionSpace<A>> {
int getStepCounter(); int getStepCounter();
LConfiguration getConfiguration(); ILearningConfiguration getConfiguration();
MDP<O, A, AS> getMdp(); MDP<O, A, AS> getMdp();
IHistoryProcessor getHistoryProcessor(); IHistoryProcessor getHistoryProcessor();
interface LConfiguration {
Integer getSeed();
int getMaxEpochStep();
int getMaxStep();
double getGamma();
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter; import lombok.Getter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
* * <p>
* In the original paper, the authors uses Asynchronous * In the original paper, the authors uses Asynchronous
* Gradient Descent: Hogwild! It is a way to apply gradients * Gradient Descent: Hogwild! It is a way to apply gradients
* and modify a model in a lock-free manner. * and modify a model in a lock-free manner.
* * <p>
* As a way to implement this with dl4j, it is unfortunately * As a way to implement this with dl4j, it is unfortunately
* necessary at the time of writing to apply the gradient * necessary at the time of writing to apply the gradient
* (update the parameters) on a single separate global thread. * (update the parameters) on a single separate global thread.
* * <p>
* This Central thread for Asynchronous Method of reinforcement learning * This Central thread for Asynchronous Method of reinforcement learning
* enqueue the gradients coming from the different threads and update its * enqueue the gradients coming from the different threads and update its
* model and target. Those neurals nets are then synced by the other threads. * model and target. Those neurals nets are then synced by the other threads.
* * <p>
* The benefits of this thread is that the updater is "shared" between all thread * The benefits of this thread is that the updater is "shared" between all thread
* we have a single updater which is the single updater of the model contained here * we have a single updater which is the single updater of the model contained here
* * <p>
* This is similar to RMSProp with shared g and momentum * This is similar to RMSProp with shared g and momentum
* * <p>
* When Hogwild! is implemented, this could be replaced by a simple data * When Hogwild! is implemented, this could be replaced by a simple data
* structure * structure
*
*
*/ */
@Slf4j @Slf4j
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> { public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
@ -56,7 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter @Getter
final private NN current; final private NN current;
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue; final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
final private AsyncConfiguration a3cc; final private IAsyncLearningConfiguration configuration;
private final IAsyncLearning learning; private final IAsyncLearning learning;
@Getter @Getter
private AtomicInteger T = new AtomicInteger(0); private AtomicInteger T = new AtomicInteger(0);
@ -65,20 +66,20 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
@Getter @Getter
private boolean running = true; private boolean running = true;
public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) { public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
this.current = initial; this.current = initial;
target = (NN) initial.clone(); target = (NN) initial.clone();
this.a3cc = a3cc; this.configuration = configuration;
this.learning = learning; this.learning = learning;
queue = new ConcurrentLinkedQueue<>(); queue = new ConcurrentLinkedQueue<>();
} }
public boolean isTrainingComplete() { public boolean isTrainingComplete() {
return T.get() >= a3cc.getMaxStep(); return T.get() >= configuration.getMaxStep();
} }
public void enqueue(Gradient[] gradient, Integer nstep) { public void enqueue(Gradient[] gradient, Integer nstep) {
if(running && !isTrainingComplete()) { if (running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep)); queue.add(new Pair<>(gradient, nstep));
} }
} }
@ -94,9 +95,8 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
synchronized (this) { synchronized (this) {
current.applyGradient(gradient, pair.getSecond()); current.applyGradient(gradient, pair.getSecond());
} }
if (a3cc.getTargetDqnUpdateFreq() != -1 if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
&& T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond()) / configuration.getLearnerUpdateFrequency()) {
/ a3cc.getTargetDqnUpdateFreq()) {
log.info("TARGET UPDATE at T = " + T.get()); log.info("TARGET UPDATE at T = " + T.get());
synchronized (this) { synchronized (this) {
target.copy(current); target.copy(current);
@ -111,7 +111,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too. * Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
*/ */
public void terminate() { public void terminate() {
if(running) { if (running) {
running = false; running = false;
queue.clear(); queue.clear();
learning.terminate(); learning.terminate();

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -21,14 +22,17 @@ import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.listener.*; import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread() * The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads()
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals * configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
* (see setProgressEventInterval(int)) * (see setProgressEventInterval(int))
* *
@ -56,9 +60,10 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
/** /**
* Returns the configuration * Returns the configuration
* @return the configuration (see {@link AsyncConfiguration}) *
* @return the configuration (see {@link AsyncQLearningConfiguration})
*/ */
public abstract AsyncConfiguration getConfiguration(); public abstract IAsyncLearningConfiguration getConfiguration();
protected abstract AsyncThread newThread(int i, int deviceAffinity); protected abstract AsyncThread newThread(int i, int deviceAffinity);
@ -77,12 +82,13 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
/** /**
* Number of milliseconds between calls to onTrainingProgress * Number of milliseconds between calls to onTrainingProgress
*/ */
@Getter @Setter @Getter
@Setter
private int progressMonitorFrequency = 20000; private int progressMonitorFrequency = 20000;
private void launchThreads() { private void launchThreads() {
startGlobalThread(); startGlobalThread();
for (int i = 0; i < getConfiguration().getNumThread(); i++) { for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices()); Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
t.start(); t.start();
} }
@ -132,7 +138,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
monitorThread = Thread.currentThread(); monitorThread = Thread.currentThread();
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) { while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
canContinue = listeners.notifyTrainingProgress(this); canContinue = listeners.notifyTrainingProgress(this);
if(!canContinue) { if (!canContinue) {
return; return;
} }
@ -155,11 +161,11 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated. * Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
*/ */
public void terminate() { public void terminate() {
if(canContinue) { if (canContinue) {
canContinue = false; canContinue = false;
Thread safeMonitorThread = monitorThread; Thread safeMonitorThread = monitorThread;
if(safeMonitorThread != null) { if (safeMonitorThread != null) {
safeMonitorThread.interrupt(); safeMonitorThread.interrupt();
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -23,6 +24,7 @@ import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.*; import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
@ -155,7 +157,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
} }
private void handleTraining(RunContext context) { private void handleTraining(RunContext context) {
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep); int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps); SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs(); context.obs = subEpochReturn.getLastObs();
@ -197,7 +199,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
protected abstract IAsyncGlobal<NN> getAsyncGlobal(); protected abstract IAsyncGlobal<NN> getAsyncGlobal();
protected abstract AsyncConfiguration getConf(); protected abstract IAsyncLearningConfiguration getConf();
protected abstract IPolicy<O, A> getPolicy(NN net); protected abstract IPolicy<O, A> getPolicy(NN net);

View File

@ -1,5 +1,7 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
rewards.add(new MiniTrans(obs.getData(), null, null, 0)); rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else { else {
INDArray[] output = null; INDArray[] output = null;
if (getConf().getTargetDqnUpdateFreq() == -1) if (getConf().getLearnerUpdateFrequency() == -1)
output = current.outputAll(obs.getData()); output = current.outputAll(obs.getData());
else synchronized (getAsyncGlobal()) { else synchronized (getAsyncGlobal()) {
output = getAsyncGlobal().getTarget().outputAll(obs.getData()); output = getAsyncGlobal().getTarget().outputAll(obs.getData());

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,11 +17,15 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete; package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.*; import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.ACPolicy;
@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
* Training for A3C in the Discrete Domain * Training for A3C in the Discrete Domain
* * <p>
* All methods are fully implemented as described in the * All methods are fully implemented as described in the
* https://arxiv.org/abs/1602.01783 paper. * https://arxiv.org/abs/1602.01783 paper.
*
*/ */
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> { public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
@Getter @Getter
final public A3CConfiguration configuration; final public A3CLearningConfiguration configuration;
@Getter @Getter
final protected MDP<O, Integer, DiscreteSpace> mdp; final protected MDP<O, Integer, DiscreteSpace> mdp;
final private IActorCritic iActorCritic; final private IActorCritic iActorCritic;
@ -49,15 +53,15 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@Getter @Getter
final private ACPolicy<O> policy; final private ACPolicy<O> policy;
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) { public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
this.iActorCritic = iActorCritic; this.iActorCritic = iActorCritic;
this.mdp = mdp; this.mdp = mdp;
this.configuration = conf; this.configuration = conf;
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this); asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
Integer seed = conf.getSeed(); Long seed = conf.getSeed();
Random rnd = Nd4j.getRandom(); Random rnd = Nd4j.getRandom();
if(seed != null) { if (seed != null) {
rnd.setSeed(seed); rnd.setSeed(seed);
} }
@ -65,7 +69,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
} }
protected AsyncThread newThread(int i, int deviceNum) { protected AsyncThread newThread(int i, int deviceNum) {
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i); return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, this.getConfiguration(), deviceNum, getListeners(), i);
} }
public IActorCritic getNeuralNet() { public IActorCritic getNeuralNet() {
@ -76,9 +80,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
public static class A3CConfiguration implements AsyncConfiguration { public static class A3CConfiguration {
Integer seed; int seed;
int maxEpochStep; int maxEpochStep;
int maxStep; int maxStep;
int numThread; int numThread;
@ -88,8 +92,20 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
double gamma; double gamma;
double errorClamp; double errorClamp;
public int getTargetDqnUpdateFreq() { /**
return -1; * Converts the deprecated A3CConfiguration to the new LearningConfiguration format
*/
public A3CLearningConfiguration toLearningConfiguration() {
return A3CLearningConfiguration.builder()
.seed(new Long(seed))
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.numThreads(numThread)
.nStep(nstep)
.rewardFactor(rewardFactor)
.gamma(gamma)
.build();
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -18,10 +19,12 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
@ -29,16 +32,15 @@ import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
* * <p>
* Training for A3C in the Discrete Domain * Training for A3C in the Discrete Domain
* * <p>
* Specialized constructors for the Conv (pixels input) case * Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety * Specialized conf + provide additional type safety
* * <p>
* It uses CompGraph because there is benefit to combine the * It uses CompGraph because there is benefit to combine the
* first layers since they're essentially doing the same dimension * first layers since they're essentially doing the same dimension
* reduction task * reduction task
*
**/ **/
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> { public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
@ -50,8 +52,18 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
this(mdp, actorCritic, hpconf, conf); this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
super(mdp, IActorCritic, conf.toLearningConfiguration());
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
super(mdp, IActorCritic, conf); super(mdp, IActorCritic, conf);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
@ -62,18 +74,32 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated @Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager); this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
} }
@Deprecated
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete; package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.network.ac.*;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
@ -25,12 +28,11 @@ import org.deeplearning4j.rl4j.util.IDataManager;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
* * <p>
* Training for A3C in the Discrete Domain * Training for A3C in the Discrete Domain
* * <p>
* We use specifically the Separate version because * We use specifically the Separate version because
* the model is too small to have enough benefit by sharing layers * the model is too small to have enough benefit by sharing layers
*
*/ */
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> { public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
@ -40,7 +42,13 @@ public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
this(mdp, IActorCritic, conf); this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) { public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
super(mdp, actorCritic, conf.toLearningConfiguration());
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
super(mdp, actorCritic, conf); super(mdp, actorCritic, conf);
} }
@ -50,19 +58,33 @@ public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) { A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated @Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
IDataManager dataManager) { IDataManager dataManager) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager); this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
} }
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
}
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
} }
@ -72,20 +94,15 @@ public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
@Deprecated
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory, public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
A3CConfiguration conf) { A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
@Deprecated public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, A3CLearningConfiguration conf) {
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf, this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
IDataManager dataManager) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
} }
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.api.rng.Random;
import java.util.Stack; import java.util.Stack;
@ -45,7 +46,7 @@ import java.util.Stack;
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> { public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
@Getter @Getter
final protected A3CDiscrete.A3CConfiguration conf; final protected A3CLearningConfiguration conf;
@Getter @Getter
final protected IAsyncGlobal<IActorCritic> asyncGlobal; final protected IAsyncGlobal<IActorCritic> asyncGlobal;
@Getter @Getter
@ -54,14 +55,14 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
final private Random rnd; final private Random rnd;
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal, public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners, A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) { int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc; this.conf = a3cc;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
Integer seed = conf.getSeed(); Long seed = conf.getSeed();
rnd = Nd4j.getRandom(); rnd = Nd4j.getRandom();
if(seed != null) { if(seed != null) {
rnd.setSeed(seed + threadNumber); rnd.setSeed(seed + threadNumber);

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,18 +17,21 @@
package org.deeplearning4j.rl4j.learning.async.nstep.discrete; package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import lombok.*; import lombok.AllArgsConstructor;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal; import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning; import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
@ -36,14 +40,14 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> { extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
@Getter @Getter
final public AsyncNStepQLConfiguration configuration; final public AsyncQLearningConfiguration configuration;
@Getter @Getter
final private MDP<O, Integer, DiscreteSpace> mdp; final private MDP<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
final private AsyncGlobal<IDQN> asyncGlobal; final private AsyncGlobal<IDQN> asyncGlobal;
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
this.mdp = mdp; this.mdp = mdp;
this.configuration = conf; this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this); this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
return new DQNPolicy<O>(getNeuralNet()); return new DQNPolicy<O>(getNeuralNet());
} }
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
public static class AsyncNStepQLConfiguration implements AsyncConfiguration { public static class AsyncNStepQLConfiguration {
Integer seed; Integer seed;
int maxEpochStep; int maxEpochStep;
@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
float minEpsilon; float minEpsilon;
int epsilonNbStep; int epsilonNbStep;
public AsyncQLearningConfiguration toLearningConfiguration() {
return AsyncQLearningConfiguration.builder()
.seed(new Long(seed))
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.numThreads(numThread)
.nStep(nstep)
.targetDqnUpdateFreq(targetDqnUpdateFreq)
.updateStart(updateStart)
.rewardFactor(rewardFactor)
.gamma(gamma)
.errorClamp(errorClamp)
.minEpsilon(minEpsilon)
.build();
} }
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -18,7 +19,9 @@ package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
@Deprecated @Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf); this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf); super(mdp, dqn, conf);
this.hpconf = hpconf; this.hpconf = hpconf;
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
@Deprecated @Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
@Deprecated @Deprecated
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
} }
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) { HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
} }

View File

@ -1,5 +1,7 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,7 +18,9 @@
package org.deeplearning4j.rl4j.learning.async.nstep.discrete; package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -33,12 +37,18 @@ public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends Async
@Deprecated @Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf, IDataManager dataManager) { AsyncNStepQLConfiguration conf, IDataManager dataManager) {
super(mdp, dqn, conf); super(mdp, dqn, conf.toLearningConfiguration());
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) { AsyncNStepQLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration());
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf); super(mdp, dqn, conf);
} }
@ -48,19 +58,34 @@ public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends Async
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) { AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated @Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager); this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
} }
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf); this(mdp, new DQNFactoryStdDense(netConf), conf);
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete; import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal; import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -42,7 +44,7 @@ import java.util.Stack;
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> { public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
@Getter @Getter
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf; final protected AsyncQLearningConfiguration conf;
@Getter @Getter
final protected IAsyncGlobal<IDQN> asyncGlobal; final protected IAsyncGlobal<IDQN> asyncGlobal;
@Getter @Getter
@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
final private Random rnd; final private Random rnd;
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal, public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, AsyncQLearningConfiguration conf,
TrainingListenerList listeners, int threadNumber, int deviceNum) { TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf; this.conf = conf;
@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
this.threadNumber = threadNumber; this.threadNumber = threadNumber;
rnd = Nd4j.getRandom(); rnd = Nd4j.getRandom();
Integer seed = conf.getSeed(); Long seed = conf.getSeed();
if(seed != null) { if(seed != null) {
rnd.setSeed(seed + threadNumber); rnd.setSeed(seed + threadNumber);
} }

View File

@ -0,0 +1,46 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration {
/**
* The number of asynchronous threads to use to generate gradients
*/
private final int numThreads;
/**
* The number of steps to calculate gradients over
*/
private final int nStep;
/**
* The frequency of async training iterations to update the target network.
*
* If this is set to -1 then the target network is updated after every training iteration
*/
@Builder.Default
private int learnerUpdateFrequency = -1;
}

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration {
/**
* The number of asynchronous threads to use to generate experience data
*/
private final int numThreads;
/**
* The number of steps in each training interations
*/
private final int nStep;
public int getLearnerUpdateFrequency() {
return getTargetDqnUpdateFreq();
}
}

View File

@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
public interface IAsyncLearningConfiguration extends ILearningConfiguration {
int getNumThreads();
int getNStep();
int getLearnerUpdateFrequency();
int getMaxStep();
}

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,36 +14,16 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.configuration;
import org.deeplearning4j.rl4j.learning.ILearning; public interface ILearningConfiguration {
Long getSeed();
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16.
*
* Interface configuration for all training method that inherit
* from AsyncLearning
*/
public interface AsyncConfiguration extends ILearning.LConfiguration {
Integer getSeed();
int getMaxEpochStep(); int getMaxEpochStep();
int getMaxStep(); int getMaxStep();
int getNumThread();
int getNstep();
int getTargetDqnUpdateFreq();
int getUpdateStart();
double getRewardFactor();
double getGamma(); double getGamma();
double getErrorClamp(); double getRewardFactor();
} }

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
public class LearningConfiguration implements ILearningConfiguration {
/**
* Seed value used for training
*/
@Builder.Default
private Long seed = System.currentTimeMillis();
/**
* The maximum number of steps in each episode
*/
@Builder.Default
private int maxEpochStep = 200;
/**
* The maximum number of steps to train for
*/
@Builder.Default
private int maxStep = 150000;
/**
* Gamma parameter used for discounted rewards
*/
@Builder.Default
private double gamma = 0.99;
/**
* Scaling parameter for rewards
*/
@Builder.Default
private double rewardFactor = 1.0;
}

View File

@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class QLearningConfiguration extends LearningConfiguration {
/**
* The maximum size of the experience replay buffer
*/
@Builder.Default
private int expRepMaxSize = 150000;
/**
* The batch size of experience for each training iteration
*/
@Builder.Default
private int batchSize = 32;
/**
* How many steps between target network updates
*/
@Builder.Default
private int targetDqnUpdateFreq = 100;
/**
* The number of steps to initially wait for until samplling batches from experience replay buffer
*/
@Builder.Default
private int updateStart = 10;
/**
* Prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will result in no clamping
*/
@Builder.Default
private double errorClamp = 1.0;
/**
* The minimum probability for random exploration action during episilon-greedy annealing
*/
@Builder.Default
private double minEpsilon = 0.1f;
/**
* The number of steps to anneal epsilon to its minimum value.
*/
@Builder.Default
private int epsilonNbStep = 10000;
/**
* Whether to use the double DQN algorithm
*/
@Builder.Default
private boolean doubleDQN = false;
}

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -63,7 +64,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
/** /**
* This method will train the model<p> * This method will train the model<p>
* The training stop when:<br> * The training stop when:<br>
* - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br> * - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
* OR<br> * OR<br>
* - a listener explicitly stops it<br> * - a listener explicitly stops it<br>
* <p> * <p>

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder; import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import lombok.*; import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay; import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning; import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
@ -59,15 +69,15 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper(); protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
public QLearning(QLConfiguration conf) { public QLearning(QLearningConfiguration conf) {
this(conf, getSeededRandom(conf.getSeed())); this(conf, getSeededRandom(conf.getSeed()));
} }
public QLearning(QLConfiguration conf, Random random) { public QLearning(QLearningConfiguration conf, Random random) {
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random); expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
} }
private static Random getSeededRandom(Integer seed) { private static Random getSeededRandom(Long seed) {
Random rnd = Nd4j.getRandom(); Random rnd = Nd4j.getRandom();
if(seed != null) { if(seed != null) {
rnd.setSeed(seed); rnd.setSeed(seed);
@ -95,7 +105,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
return getQNetwork(); return getQNetwork();
} }
public abstract QLConfiguration getConfiguration(); public abstract QLearningConfiguration getConfiguration();
protected abstract void preEpoch(); protected abstract void preEpoch();
@ -198,7 +208,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
double reward; double reward;
int episodeLength; int episodeLength;
List<Double> scores; List<Double> scores;
float epsilon; double epsilon;
double startQ; double startQ;
double meanQ; double meanQ;
} }
@ -213,12 +223,14 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
} }
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@Deprecated
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class) @JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
public static class QLConfiguration implements LConfiguration { public static class QLConfiguration {
Integer seed; Integer seed;
int maxEpochStep; int maxEpochStep;
@ -237,7 +249,25 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
@JsonPOJOBuilder(withPrefix = "") @JsonPOJOBuilder(withPrefix = "")
public static final class QLConfigurationBuilder { public static final class QLConfigurationBuilder {
} }
public QLearningConfiguration toLearningConfiguration() {
return QLearningConfiguration.builder()
.seed(seed.longValue())
.maxEpochStep(maxEpochStep)
.maxStep(maxStep)
.expRepMaxSize(expRepMaxSize)
.batchSize(batchSize)
.targetDqnUpdateFreq(targetDqnUpdateFreq)
.updateStart(updateStart)
.rewardFactor(rewardFactor)
.gamma(gamma)
.errorClamp(errorClamp)
.minEpsilon(minEpsilon)
.epsilonNbStep(epsilonNbStep)
.doubleDQN(doubleDQN)
.build();
}
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -22,6 +23,7 @@ import lombok.Setter;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
@ -45,16 +47,15 @@ import java.util.ArrayList;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
* * <p>
* DQN or Deep Q-Learning in the Discrete domain * DQN or Deep Q-Learning in the Discrete domain
* * <p>
* http://arxiv.org/abs/1312.5602 * http://arxiv.org/abs/1312.5602
*
*/ */
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> { public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
@Getter @Getter
final private QLConfiguration configuration; final private QLearningConfiguration configuration;
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp; private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
@Getter @Getter
private DQNPolicy<O> policy; private DQNPolicy<O> policy;
@ -78,16 +79,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
return mdp; return mdp;
} }
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) {
int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed())); this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
} }
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf, public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) { int epsilonNbStep, Random random) {
super(conf); super(conf);
this.configuration = conf; this.configuration = conf;
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this); this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
qNetwork = dqn; qNetwork = dqn;
targetQNetwork = dqn.clone(); targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork()); policy = new DQNPolicy(getQNetwork());
@ -125,6 +125,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
/** /**
* Single step of training * Single step of training
*
* @param obs last obs * @param obs last obs
* @return relevant info for next step * @return relevant info for next step
*/ */
@ -135,8 +136,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
boolean isHistoryProcessor = getHistoryProcessor() != null; boolean isHistoryProcessor = getHistoryProcessor() != null;
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1; int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1; int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
int updateStart = getConfiguration().getUpdateStart() int updateStart = this.getConfiguration().getUpdateStart()
+ ((getConfiguration().getBatchSize() + historyLength) * skipFrame); + ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
Double maxQ = Double.NaN; //ignore if Nan for stats Double maxQ = Double.NaN; //ignore if Nan for stats
@ -161,7 +162,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
if (!obs.isSkipped()) { if (!obs.isSkipped()) {
// Add experience // Add experience
if(pendingTransition != null) { if (pendingTransition != null) {
pendingTransition.setNextObservation(obs); pendingTransition.setNextObservation(obs);
getExpReplay().store(pendingTransition); getExpReplay().store(pendingTransition);
} }
@ -188,7 +189,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
@Override @Override
protected void finishEpoch(Observation observation) { protected void finishEpoch(Observation observation) {
if(pendingTransition != null) { if (pendingTransition != null) {
pendingTransition.setNextObservation(observation); pendingTransition.setNextObservation(observation);
getExpReplay().store(pendingTransition); getExpReplay().store(pendingTransition);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,7 +18,9 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor; import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -40,8 +43,16 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
this(mdp, dqn, hpconf, conf); this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) { QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame()); super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf); setHistoryProcessor(hpconf);
} }
@ -51,18 +62,32 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
} }
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) { HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
} }
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated @Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) { HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
} }
@Deprecated
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf, public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) { HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
}
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
@ -38,7 +41,13 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
this(mdp, dqn, conf); this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
} }
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) { public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep()); super(mdp, dqn, conf, conf.getEpsilonNbStep());
} }
@ -48,18 +57,33 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager); dataManager);
} }
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearning.QLConfiguration conf) { QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
} }
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
QLearningConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated @Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) { QLearning.QLConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
} }
@Deprecated
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf, public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) { QLearning.QLConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
}
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf); this(mdp, new DQNFactoryStdDense(netConf), conf);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -36,7 +37,7 @@ import java.util.Collection;
* *
* Standard implementation of ActorCriticCompGraph * Standard implementation of ActorCriticCompGraph
*/ */
public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IActorCritic<NN> { public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph> {
final protected ComputationGraph cg; final protected ComputationGraph cg;
@Getter @Getter
@ -73,13 +74,13 @@ public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IA
} }
} }
public NN clone() { public ActorCriticCompGraph clone() {
NN nn = (NN)new ActorCriticCompGraph(cg.clone()); ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone());
nn.cg.setListeners(cg.getListeners()); nn.cg.setListeners(cg.getListeners());
return nn; return nn;
} }
public void copy(NN from) { public void copy(ActorCriticCompGraph from) {
cg.setParams(from.cg.params()); cg.setParams(from.cg.params());
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
* *
@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value @Value
public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph { public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph {
ActorCriticNetworkConfiguration conf;
Configuration conf;
public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) { public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) {
@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom
return new ActorCriticCompGraph(model); return new ActorCriticCompGraph(model);
} }
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@Value @Value
@Deprecated
public static class Configuration { public static class Configuration {
double l2; double l2;
IUpdater updater; IUpdater updater;
TrainingListener[] listeners; TrainingListener[] listeners;
boolean useLSTM; boolean useLSTM;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public ActorCriticNetworkConfiguration toNetworkConfiguration() {
ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder()
.l2(l2)
.updater(updater)
.useLSTM(useLSTM);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,8 +17,6 @@
package org.deeplearning4j.rl4j.network.ac; package org.deeplearning4j.rl4j.network.ac;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Value; import lombok.Value;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
/** /**
@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value @Value
public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph { public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph {
Configuration conf; ActorCriticDenseNetworkConfiguration conf;
public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) { public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1; int nIn = 1;
@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
"input"); "input");
for (int i = 1; i < conf.getNumLayer(); i++) { for (int i = 1; i < conf.getNumLayers(); i++) {
confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build(), (i - 1) + ""); .activation(Activation.RELU).build(), (i - 1) + "");
} }
if (conf.isUseLSTM()) { if (conf.isUseLSTM()) {
confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH) confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH)
.nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + ""); .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nOut(1).build(), getConf().getNumLayer() + ""); .nOut(1).build(), getConf().getNumLayers() + "");
confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
.nOut(numOutputs).build(), getConf().getNumLayer() + ""); .nOut(numOutputs).build(), getConf().getNumLayers() + "");
} else { } else {
confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nOut(1).build(), (getConf().getNumLayer() - 1) + ""); .nOut(1).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX) confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
.nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + ""); .nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + "");
} }
confB.setOutputs("value", "softmax"); confB.setOutputs("value", "softmax");
@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
return new ActorCriticCompGraph(model); return new ActorCriticCompGraph(model);
} }
@AllArgsConstructor
@Builder
@Value
public static class Configuration {
int numLayer;
int numHiddenNodes;
double l2;
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
*
*
*/ */
@Value @Value
public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate { public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
Configuration conf; ActorCriticDenseNetworkConfiguration conf;
public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) { public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1; int nIn = 1;
@ -61,18 +65,18 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
.activation(Activation.RELU).build()); .activation(Activation.RELU).build());
for (int i = 1; i < conf.getNumLayer(); i++) { for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build()); .activation(Activation.RELU).build());
} }
if (conf.isUseLSTM()) { if (conf.isUseLSTM()) {
confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build()); .nIn(conf.getNumHiddenNodes()).nOut(1).build());
} else { } else {
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes()).nOut(1).build()); .nIn(conf.getNumHiddenNodes()).nOut(1).build());
} }
@ -96,18 +100,18 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
.activation(Activation.RELU).build()); .activation(Activation.RELU).build());
for (int i = 1; i < conf.getNumLayer(); i++) { for (int i = 1; i < conf.getNumLayers(); i++) {
confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build()); .activation(Activation.RELU).build());
} }
if (conf.isUseLSTM()) { if (conf.isUseLSTM()) {
confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build()); confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss()) confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
} else { } else {
confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss()) confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss())
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
} }
@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
@AllArgsConstructor @AllArgsConstructor
@Value @Value
@Builder @Builder
@Deprecated
public static class Configuration { public static class Configuration {
int numLayer; int numLayer;
@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
IUpdater updater; IUpdater updater;
TrainingListener[] listeners; TrainingListener[] listeners;
boolean useLSTM; boolean useLSTM;
public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() {
ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder()
.numHiddenNodes(numHiddenNodes)
.numLayers(numLayer)
.l2(l2)
.updater(updater)
.useLSTM(useLSTM);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
} }

View File

@ -0,0 +1,42 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration {
/**
* The number of layers in the dense network
*/
@Builder.Default
private int numLayers = 3;
/**
* The number of hidden neurons in each layer
*/
@Builder.Default
private int numHiddenNodes = 100;
}

View File

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class ActorCriticNetworkConfiguration extends NetworkConfiguration {
/**
* Whether or not to add an LSTM layer to the network.
*/
@Builder.Default
private boolean useLSTM = false;
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DQNDenseNetworkConfiguration extends NetworkConfiguration {
/**
* The number of layers in the dense network
*/
@Builder.Default
private int numLayers = 3;
/**
* The number of hidden neurons in each layer
*/
@Builder.Default
private int numHiddenNodes = 100;
}

View File

@ -0,0 +1,58 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network.configuration;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Singular;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.learning.config.IUpdater;
import java.util.List;
@Data
@SuperBuilder
@NoArgsConstructor
public class NetworkConfiguration {
/**
* The learning rate of the network
*/
@Builder.Default
private double learningRate = 0.01;
/**
* L2 regularization on the network
*/
@Builder.Default
private double l2 = 0.0;
/**
* The network's gradient update algorithm
*/
private IUpdater updater;
/**
* Training listeners attached to the network
*/
@Singular
private List<TrainingListener> listeners;
}

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/ */
@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
public class DQNFactoryStdConv implements DQNFactory { public class DQNFactoryStdConv implements DQNFactory {
Configuration conf; NetworkConfiguration conf;
public DQN buildDQN(int shapeInputs[], int numOutputs) { public DQN buildDQN(int shapeInputs[], int numOutputs) {
@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory {
return new DQN(model); return new DQN(model);
} }
@AllArgsConstructor @AllArgsConstructor
@Builder @Builder
@Value @Value
@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory {
double l2; double l2;
IUpdater updater; IUpdater updater;
TrainingListener[] listeners; TrainingListener[] listeners;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public NetworkConfiguration toNetworkConfiguration() {
NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder()
.learningRate(learningRate)
.l2(l2)
.updater(updater);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
}
return builder.build();
}
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants; import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/ */
@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value @Value
public class DQNFactoryStdDense implements DQNFactory { public class DQNFactoryStdDense implements DQNFactory {
DQNDenseNetworkConfiguration conf;
Configuration conf;
public DQN buildDQN(int[] numInputs, int numOutputs) { public DQN buildDQN(int[] numInputs, int numOutputs) {
int nIn = 1; int nIn = 1;
for (int i : numInputs) { for (int i : numInputs) {
nIn *= i; nIn *= i;
} }
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
//.updater(Updater.NESTEROVS).momentum(0.9)
//.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay())
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.l2(conf.getL2()) .l2(conf.getL2())
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes()) .list()
.activation(Activation.RELU).build()); .layer(0,
new DenseLayer.Builder()
.nIn(nIn)
.nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build()
);
for (int i = 1; i < conf.getNumLayer(); i++) { for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes()) confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build()); .activation(Activation.RELU).build());
} }
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY) confB.layer(conf.getNumLayers(),
.nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build()); new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.IDENTITY)
.nIn(conf.getNumHiddenNodes())
.nOut(numOutputs)
.build()
);
MultiLayerConfiguration mlnconf = confB.build(); MultiLayerConfiguration mlnconf = confB.build();
@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory {
@AllArgsConstructor @AllArgsConstructor
@Value @Value
@Builder @Builder
@Deprecated
public static class Configuration { public static class Configuration {
int numLayer; int numLayer;
@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory {
double l2; double l2;
IUpdater updater; IUpdater updater;
TrainingListener[] listeners; TrainingListener[] listeners;
/**
* Converts the deprecated Configuration to the new NetworkConfiguration format
*/
public DQNDenseNetworkConfiguration toNetworkConfiguration() {
DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder()
.numHiddenNodes(numHiddenNodes)
.numLayers(numLayer)
.l2(l2)
.updater(updater);
if (listeners != null) {
builder.listeners(Arrays.asList(listeners));
} }
return builder.build();
}
}
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
final private int updateStart; final private int updateStart;
final private int epsilonNbStep; final private int epsilonNbStep;
final private Random rnd; final private Random rnd;
final private float minEpsilon; final private double minEpsilon;
final private IEpochTrainer learning; final private IEpochTrainer learning;
public NeuralNet getNeuralNet() { public NeuralNet getNeuralNet() {
@ -55,10 +56,10 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
public A nextAction(INDArray input) { public A nextAction(INDArray input) {
float ep = getEpsilon(); double ep = getEpsilon();
if (learning.getStepCounter() % 500 == 1) if (learning.getStepCounter() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCounter()); log.info("EP: " + ep + " " + learning.getStepCounter());
if (rnd.nextFloat() > ep) if (rnd.nextDouble() > ep)
return policy.nextAction(input); return policy.nextAction(input);
else else
return mdp.getActionSpace().randomAction(); return mdp.getActionSpace().randomAction();
@ -68,7 +69,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
return this.nextAction(observation.getData()); return this.nextAction(observation.getData());
} }
public float getEpsilon() { public double getEpsilon() {
return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep)); return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
} }
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -22,17 +23,30 @@ import lombok.Builder;
import lombok.Getter; import lombok.Getter;
import lombok.Value; import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.dqn.DQN; import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.primitives.Pair;
import java.io.*; import java.io.BufferedOutputStream;
import java.nio.file.*; import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.util.zip.ZipEntry; import java.util.zip.ZipEntry;
import java.util.zip.ZipFile; import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream; import java.util.zip.ZipOutputStream;
@ -304,7 +318,7 @@ public class DataManager implements IDataManager {
public static class Info { public static class Info {
String trainingName; String trainingName;
String mdpName; String mdpName;
ILearning.LConfiguration conf; ILearningConfiguration conf;
int stepCounter; int stepCounter;
long millisTime; long millisTime;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,14 +17,11 @@
package org.deeplearning4j.rl4j.learning; package org.deeplearning4j.rl4j.learning;
import java.util.Arrays;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/** /**
* *
@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue;
public class HistoryProcessorTest { public class HistoryProcessorTest {
@Test @Test
public void testHistoryProcessor() throws Exception { public void testHistoryProcessor() {
HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder() HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder()
.croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build(); .croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build();
IHistoryProcessor hp = new HistoryProcessor(conf); IHistoryProcessor hp = new HistoryProcessor(conf);
@ -43,8 +41,6 @@ public class HistoryProcessorTest {
hp.add(a); hp.add(a);
INDArray[] h = hp.getHistory(); INDArray[] h = hp.getHistory();
assertEquals(4, h.length); assertEquals(4, h.length);
// System.out.println(Arrays.toString(a.shape()));
// System.out.println(Arrays.toString(h[0].shape()));
assertEquals( 1, h[0].shape()[0]); assertEquals( 1, h[0].shape()[0]);
assertEquals(a.shape()[0], h[0].shape()[1]); assertEquals(a.shape()[0], h[0].shape()[1]);
assertEquals(a.shape()[1], h[0].shape()[2]); assertEquals(a.shape()[1], h[0].shape()[2]);

View File

@ -1,9 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockPolicy;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@ -68,7 +91,7 @@ public class AsyncLearningTest {
public static class TestContext { public static class TestContext {
MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0); MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal(); public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy(); public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy); public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
@ -82,11 +105,11 @@ public class AsyncLearningTest {
} }
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> { public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
private final AsyncConfiguration conf; private final IAsyncLearningConfiguration conf;
private final IAsyncGlobal asyncGlobal; private final IAsyncGlobal asyncGlobal;
private final IPolicy<MockEncodable, Integer> policy; private final IPolicy<MockEncodable, Integer> policy;
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) { public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
this.conf = conf; this.conf = conf;
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
this.policy = policy; this.policy = policy;
@ -98,7 +121,7 @@ public class AsyncLearningTest {
} }
@Override @Override
public AsyncConfiguration getConfiguration() { public IAsyncLearningConfiguration getConfiguration() {
return conf; return conf;
} }

View File

@ -1,7 +1,25 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest {
MockMDP mdpMock = new MockMDP(observationSpace); MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList(); TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy(); MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest {
} }
@Override @Override
protected AsyncConfiguration getConf() { protected IAsyncLearningConfiguration getConf() {
return config; return config;
} }

View File

@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test; import org.junit.Test;
@ -16,7 +24,6 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
public class AsyncThreadTest { public class AsyncThreadTest {
@ -126,7 +133,7 @@ public class AsyncThreadTest {
public final MockNeuralNet neuralNet = new MockNeuralNet(); public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace(); public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace); public final MockMDP mdp = new MockMDP(observationSpace);
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0); public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
public final TrainingListenerList listeners = new TrainingListenerList(); public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener(); public final MockTrainingListener listener = new MockTrainingListener();
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
@ -149,11 +156,11 @@ public class AsyncThreadTest {
private final MockAsyncGlobal asyncGlobal; private final MockAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet; private final MockNeuralNet neuralNet;
private final AsyncConfiguration conf; private final IAsyncLearningConfiguration conf;
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>(); private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) { public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0); super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal; this.asyncGlobal = asyncGlobal;
@ -184,7 +191,7 @@ public class AsyncThreadTest {
} }
@Override @Override
protected AsyncConfiguration getConf() { protected IAsyncLearningConfiguration getConf() {
return conf; return conf;
} }

View File

@ -1,11 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.a3c.discrete; package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete; import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.*;
@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest {
double gamma = 0.9; double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace(); MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace); MockMDP mdpMock = new MockMDP(observationSpace);
A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0); A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
MockActorCritic actorCriticMock = new MockActorCritic(); MockActorCritic actorCriticMock = new MockActorCritic();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock); MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest {
Nd4j.zeros(5) Nd4j.zeros(5)
}; };
output[0].putScalar(i, outputs[i]); output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i])); minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
} }
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act // Act
sut.calcGradient(actorCriticMock, minitransList); sut.calcGradient(actorCriticMock, minitransList);

View File

@ -1,7 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2020 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete; package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans; import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.*;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest {
double gamma = 0.9; double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace(); MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace); MockMDP mdpMock = new MockMDP(observationSpace);
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0); AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
MockDQN dqnMock = new MockDQN(); MockDQN dqnMock = new MockDQN();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock); MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest {
Nd4j.zeros(5) Nd4j.zeros(5)
}; };
output[0].putScalar(i, outputs[i]); output[0].putScalar(i, outputs[i]);
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i])); minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
} }
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act // Act
sut.calcGradient(dqnMock, minitransList); sut.calcGradient(dqnMock, minitransList);

View File

@ -1,6 +1,26 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync; package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
@ -17,7 +37,7 @@ public class SyncLearningTest {
@Test @Test
public void when_training_expect_listenersToBeCalled() { public void when_training_expect_listenersToBeCalled() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
@ -34,7 +54,7 @@ public class SyncLearningTest {
@Test @Test
public void when_trainingStartCanContinueFalse_expect_trainingStopped() { public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
@ -52,7 +72,7 @@ public class SyncLearningTest {
@Test @Test
public void when_newEpochCanContinueFalse_expect_trainingStopped() { public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
@ -70,7 +90,7 @@ public class SyncLearningTest {
@Test @Test
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() { public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange // Arrange
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build(); LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener(); MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig); MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener); sut.addListener(listener);
@ -87,12 +107,12 @@ public class SyncLearningTest {
public static class MockSyncLearning extends SyncLearning { public static class MockSyncLearning extends SyncLearning {
private final LConfiguration conf; private final ILearningConfiguration conf;
@Getter @Getter
private int currentEpochStep = 0; private int currentEpochStep = 0;
public MockSyncLearning(LConfiguration conf) { public MockSyncLearning(ILearningConfiguration conf) {
this.conf = conf; this.conf = conf;
} }
@ -119,7 +139,7 @@ public class SyncLearningTest {
} }
@Override @Override
public LConfiguration getConfiguration() { public ILearningConfiguration getConfiguration() {
return conf; return conf;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,36 +18,24 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning; package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
public class QLConfigurationTest { public class QLearningConfigurationTest {
@Rule @Rule
public ExpectedException thrown = ExpectedException.none(); public ExpectedException thrown = ExpectedException.none();
@Test @Test
public void serialize() throws Exception { public void serialize() throws Exception {
ObjectMapper mapper = new ObjectMapper(); ObjectMapper mapper = new ObjectMapper();
QLearning.QLConfiguration qlConfiguration =
new QLearning.QLConfiguration( QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder()
123, //Random seed .build();
200, //Max step By epoch
8000, //Max step
150000, //Max size of experience replay
32, //size of batches
500, //target update (hard)
10, //num step noop warmup
0.01, //reward scaling
0.99, //gamma
1.0, //td error clipping
0.1f, //min epsilon
10000, //num step for eps greedy anneal
true //double DQN
);
// Should not throw.. // Should not throw..
String json = mapper.writeValueAsString(qlConfiguration); String json = mapper.writeValueAsString(qLearningConfiguration);
QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class); QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class);
} }
} }

View File

@ -1,6 +1,24 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
@ -27,7 +45,7 @@ public class QLearningDiscreteTest {
// Arrange // Arrange
MockObservationSpace observationSpace = new MockObservationSpace(); MockObservationSpace observationSpace = new MockObservationSpace();
MockDQN dqn = new MockDQN(); MockDQN dqn = new MockDQN();
MockRandom random = new MockRandom(new double[] { MockRandom random = new MockRandom(new double[]{
0.7309677600860596, 0.7309677600860596,
0.8314409852027893, 0.8314409852027893,
0.2405363917350769, 0.2405363917350769,
@ -37,13 +55,25 @@ public class QLearningDiscreteTest {
0.5504369735717773, 0.5504369735717773,
0.11700659990310669 0.11700659990310669
}, },
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
MockMDP mdp = new MockMDP(observationSpace, random); MockMDP mdp = new MockMDP(observationSpace, random);
int initStepCount = 8; int initStepCount = 8;
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000, QLearningConfiguration conf = QLearningConfiguration.builder()
initStepCount, 1.0, 0, 0, 0, 0, true); .seed(0L)
.maxEpochStep(24)
.maxStep(0)
.expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
.updateStart(initStepCount)
.rewardFactor(1.0)
.gamma(0)
.errorClamp(0)
.minEpsilon(0)
.epsilonNbStep(0)
.doubleDQN(true)
.build();
MockDataManager dataManager = new MockDataManager(false); MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay(); MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random); TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
@ -58,9 +88,9 @@ public class QLearningDiscreteTest {
// Assert // Assert
// HistoryProcessor calls // HistoryProcessor calls
double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 }; double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
assertEquals(expectedRecords.length, hp.recordCalls.size()); assertEquals(expectedRecords.length, hp.recordCalls.size());
for(int i = 0; i < expectedRecords.length; ++i) { for (int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
} }
@ -72,59 +102,59 @@ public class QLearningDiscreteTest {
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001); assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001); assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size()); assertEquals(14, dqn.outputParams.size());
double[][] expectedDQNOutput = new double[][] { double[][] expectedDQNOutput = new double[][]{
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
}; };
for(int i = 0; i < expectedDQNOutput.length; ++i) { for (int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i); INDArray outputParam = dqn.outputParams.get(i);
assertEquals(5, outputParam.shape()[1]); assertEquals(5, outputParam.shape()[1]);
assertEquals(1, outputParam.shape()[2]); assertEquals(1, outputParam.shape()[2]);
double[] expectedRow = expectedDQNOutput[i]; double[] expectedRow = expectedDQNOutput[i];
for(int j = 0; j < expectedRow.length; ++j) { for (int j = 0; j < expectedRow.length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001); assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
} }
} }
// MDP calls // MDP calls
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray()); assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
// ExpReplay calls // ExpReplay calls
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 }; double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0};
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 }; int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4};
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0};
double[][] expectedTrObservations = new double[][] { double[][] expectedTrObservations = new double[][]{
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 }, new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 }, new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 }, new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 }, new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 }, new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 }, new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 }, new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 }, new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
}; };
assertEquals(expectedTrObservations.length, expReplay.transitions.size()); assertEquals(expectedTrObservations.length, expReplay.transitions.size());
for(int i = 0; i < expectedTrRewards.length; ++i) { for (int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i); Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001); assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
assertEquals(expectedTrActions[i], tr.getAction()); assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001); assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
for(int j = 0; j < expectedTrObservations[i].length; ++j) { for (int j = 0; j < expectedTrObservations[i].length; ++j) {
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001); assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
} }
} }
@ -132,12 +162,12 @@ public class QLearningDiscreteTest {
assertEquals(initStepCount + 16, result.getStepCounter()); assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001); assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset); assertTrue(dqn.hasBeenReset);
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset); assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
} }
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> { public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn, public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay, QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
int epsilonNbStep, Random rnd) { int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd); super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager)); addListener(new DataManagerTrainingListener(dataManager));
@ -146,10 +176,10 @@ public class QLearningDiscreteTest {
@Override @Override
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) { protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0}));
} }
public void setExpReplay(IExpReplay<Integer> exp){ public void setExpReplay(IExpReplay<Integer> exp) {
this.expReplay = exp; this.expReplay = exp;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.ac; package org.deeplearning4j.rl4j.network.ac;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
/** /**
*
* @author saudet * @author saudet
*/ */
public class ActorCriticTest { public class ActorCriticTest {
public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF = public static ActorCriticDenseNetworkConfiguration NET_CONF =
new ActorCriticFactorySeparateStdDense.Configuration( ActorCriticDenseNetworkConfiguration.builder()
4, //number of layers .numLayers(4)
32, //number of hidden nodes .numHiddenNodes(32)
0.001, //l2 regularization .l2(0.001)
new RmsProp(0.0005), null, false .updater(new RmsProp(0.0005))
); .useLSTM(false)
.build();
public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG = public static ActorCriticDenseNetworkConfiguration NET_CONF_CG =
new ActorCriticFactoryCompGraphStdDense.Configuration( ActorCriticDenseNetworkConfiguration.builder()
2, //number of layers .numLayers(2)
128, //number of hidden nodes .numHiddenNodes(128)
0.00001, //l2 regularization .l2(0.00001)
new RmsProp(0.005), null, true .updater(new RmsProp(0.005))
); .useLSTM(true)
.build();
@Test @Test
public void testModelLoadSave() throws IOException { public void testModelLoadSave() throws IOException {
ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5); ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5);
File fileValue = File.createTempFile("rl4j-value-", ".model"); File fileValue = File.createTempFile("rl4j-value-", ".model");
File filePolicy = File.createTempFile("rl4j-policy-", ".model"); File filePolicy = File.createTempFile("rl4j-policy-", ".model");
@ -63,7 +66,7 @@ public class ActorCriticTest {
assertEquals(acs.valueNet, acs2.valueNet); assertEquals(acs.valueNet, acs2.valueNet);
assertEquals(acs.policyNet, acs2.policyNet); assertEquals(acs.policyNet, acs2.policyNet);
ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43); ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43);
File file = File.createTempFile("rl4j-cg-", ".model"); File file = File.createTempFile("rl4j-cg-", ".model");
accg.save(file.getAbsolutePath()); accg.save(file.getAbsolutePath());
@ -83,15 +86,15 @@ public class ActorCriticTest {
for (double i = eps; i < n; i++) { for (double i = eps; i < n; i++) {
for (double j = eps; j < n; j++) { for (double j = eps; j < n; j++) {
INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2}); INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2});
INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
INDArray gradient = loss.computeGradient(labels, output, activation, null); INDArray gradient = loss.computeGradient(labels, output, activation, null);
output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2}); output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
double score = loss.computeScore(labels, output, activation, null, false); double score = loss.computeScore(labels, output, activation, null, false);
INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2}); INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2});
double score1 = loss.computeScore(labels, output1, activation, null, false); double score1 = loss.computeScore(labels, output1, activation, null, false);
INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2}); INDArray output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2});
double score2 = loss.computeScore(labels, output2, activation, null, false); double score2 = loss.computeScore(labels, output2, activation, null, false);
double gradient1 = (score1 - score) / eps; double gradient1 = (score1 - score) / eps;

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn; package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.learning.config.RmsProp;
@ -25,22 +27,20 @@ import java.io.IOException;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
/** /**
*
* @author saudet * @author saudet
*/ */
public class DQNTest { public class DQNTest {
public static DQNFactoryStdDense.Configuration NET_CONF = private static DQNDenseNetworkConfiguration NET_CONF =
new DQNFactoryStdDense.Configuration( DQNDenseNetworkConfiguration.builder().numLayers(3)
3, //number of layers .numHiddenNodes(16)
16, //number of hidden nodes .l2(0.001)
0.001, //l2 regularization .updater(new RmsProp(0.0005))
new RmsProp(0.0005), null .build();
);
@Test @Test
public void testModelLoadSave() throws IOException { public void testModelLoadSave() throws IOException {
DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13); DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13);
File file = File.createTempFile("rl4j-dqn-", ".model"); File file = File.createTempFile("rl4j-dqn-", ".model");
dqn.save(file.getAbsolutePath()); dqn.save(file.getAbsolutePath());

View File

@ -128,7 +128,7 @@ public class TransformProcessTest {
// Assert // Assert
assertFalse(result.isSkipped()); assertFalse(result.isSkipped());
assertEquals(1, result.getData().shape().length); assertEquals(2, result.getData().shape().length);
assertEquals(1, result.getData().shape()[0]); assertEquals(1, result.getData().shape()[0]);
assertEquals(-10.0, result.getData().getDouble(0), 0.00001); assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.support.MockDQN;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.MockEncodable;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
import org.deeplearning4j.rl4j.support.MockMDP;
import org.deeplearning4j.rl4j.support.MockNeuralNet;
import org.deeplearning4j.rl4j.support.MockObservationSpace;
import org.deeplearning4j.rl4j.support.MockRandom;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
@ -186,8 +187,22 @@ public class PolicyTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, 30, random); MockMDP mdp = new MockMDP(observationSpace, 30, random);
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, QLearningConfiguration conf = QLearningConfiguration.builder()
0, 1.0, 0, 0, 0, 0, true); .seed(0L)
.maxEpochStep(0)
.maxStep(0)
.expRepMaxSize(5)
.batchSize(1)
.targetDqnUpdateFreq(0)
.updateStart(0)
.rewardFactor(1.0)
.gamma(0)
.errorClamp(0)
.minEpsilon(0)
.epsilonNbStep(0)
.doubleDQN(true)
.build();
MockNeuralNet nnMock = new MockNeuralNet(); MockNeuralNet nnMock = new MockNeuralNet();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()); MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());

View File

@ -1,22 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.support; package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Value; import lombok.Value;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
@AllArgsConstructor
@Value @Value
public class MockAsyncConfiguration implements AsyncConfiguration { @AllArgsConstructor
public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
private Integer seed; private Long seed;
private int maxEpochStep; private int maxEpochStep;
private int maxStep; private int maxStep;
private int numThread;
private int nstep;
private int targetDqnUpdateFreq;
private int updateStart; private int updateStart;
private double rewardFactor; private double rewardFactor;
private double gamma; private double gamma;
private double errorClamp; private double errorClamp;
private int numThreads;
private int nStep;
private int learnerUpdateFrequency;
} }

View File

@ -1,3 +1,20 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.util; package org.deeplearning4j.rl4j.util;
import lombok.Getter; import lombok.Getter;
@ -5,6 +22,7 @@ import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer; import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning; import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry; import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest {
} }
@Override @Override
public LConfiguration getConfiguration() { public ILearningConfiguration getConfiguration() {
return null; return null;
} }

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2019 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,19 +17,18 @@
package org.deeplearning4j.malmo; package org.deeplearning4j.malmo;
import java.util.HashMap; import com.microsoft.msr.malmo.TimestampedStringVector;
import com.microsoft.msr.malmo.WorldState;
import org.json.JSONArray; import org.json.JSONArray;
import org.json.JSONObject; import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import com.microsoft.msr.malmo.TimestampedStringVector; import java.util.HashMap;
import com.microsoft.msr.malmo.WorldState;
/** /**
* Observation space that contains a grid of Minecraft blocks * Observation space that contains a grid of Minecraft blocks
*
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17. * @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
*/ */
public class MalmoObservationSpaceGrid extends MalmoObservationSpace { public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@Override @Override
public int[] getShape() { public int[] getShape() {
return new int[] {totalSize}; return new int[]{totalSize};
} }
@Override @Override