RL4J: Add Backwardly Compatible Builder patterns (#326)
* Starting to switch configs of RL algorithms to use more fluent builder patterns. Many parameter choices in different algorithms default to SOTA and only be changed in specific cases Signed-off-by: Bam4d <chris.bam4d@gmail.com> * remove personal gpu-build file Signed-off-by: Bam4d <chrisbam4d@gmail.com> * refactored out configurations so they are heirarchical and re-usable, this is a step towards having a plug-and-play framework for different algorithms * backwardly compatible configurations * adding documentation to new configuration classes Signed-off-by: Bam4d <chrisbam4d@gmail.com> * private access modifiers are better suited here Signed-off-by: Bam4d <chrisbam4d@gmail.com> * RL4j does not compile without java 8 due to previous updates fixing null pointers when listener arrays are empty Signed-off-by: Bam4d <chrisbam4d@gmail.com> * fixing copyright headers Signed-off-by: Bam4d <chrisbam4d@gmail.com> * uncomment logging line Signed-off-by: Bam4d <chrisbam4d@gmail.com> * fixing default value for learningUpdateFrequency fixing test failure due to #352 Signed-off-by: Bam4d <chrisbam4d@gmail.com> Co-authored-by: Bam4d <chris.bam4d@gmail.com>master
parent
fb1c41c512
commit
1a35ebec2e
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.api.transform.split;
|
package org.datavec.api.transform.split;
|
||||||
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,18 @@
|
||||||
xmlns="http://maven.apache.org/POM/4.0.0"
|
xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
<modelVersion>4.0.0</modelVersion>
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
<build>
|
||||||
|
<plugins>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
<artifactId>maven-compiler-plugin</artifactId>
|
||||||
|
<configuration>
|
||||||
|
<source>8</source>
|
||||||
|
<target>8</target>
|
||||||
|
</configuration>
|
||||||
|
</plugin>
|
||||||
|
</plugins>
|
||||||
|
</build>
|
||||||
|
|
||||||
<parent>
|
<parent>
|
||||||
<artifactId>rl4j</artifactId>
|
<artifactId>rl4j</artifactId>
|
||||||
|
|
|
@ -1,3 +1,19 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning;
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
public interface EpochStepCounter {
|
public interface EpochStepCounter {
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,10 +17,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning;
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/19/16.
|
||||||
|
@ -34,21 +35,12 @@ public interface ILearning<O, A, AS extends ActionSpace<A>> {
|
||||||
|
|
||||||
int getStepCounter();
|
int getStepCounter();
|
||||||
|
|
||||||
LConfiguration getConfiguration();
|
ILearningConfiguration getConfiguration();
|
||||||
|
|
||||||
MDP<O, A, AS> getMdp();
|
MDP<O, A, AS> getMdp();
|
||||||
|
|
||||||
IHistoryProcessor getHistoryProcessor();
|
IHistoryProcessor getHistoryProcessor();
|
||||||
|
|
||||||
interface LConfiguration {
|
|
||||||
|
|
||||||
Integer getSeed();
|
|
||||||
|
|
||||||
int getMaxEpochStep();
|
|
||||||
|
|
||||||
int getMaxStep();
|
|
||||||
|
|
||||||
double getGamma();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*
|
* <p>
|
||||||
* In the original paper, the authors uses Asynchronous
|
* In the original paper, the authors uses Asynchronous
|
||||||
* Gradient Descent: Hogwild! It is a way to apply gradients
|
* Gradient Descent: Hogwild! It is a way to apply gradients
|
||||||
* and modify a model in a lock-free manner.
|
* and modify a model in a lock-free manner.
|
||||||
*
|
* <p>
|
||||||
* As a way to implement this with dl4j, it is unfortunately
|
* As a way to implement this with dl4j, it is unfortunately
|
||||||
* necessary at the time of writing to apply the gradient
|
* necessary at the time of writing to apply the gradient
|
||||||
* (update the parameters) on a single separate global thread.
|
* (update the parameters) on a single separate global thread.
|
||||||
*
|
* <p>
|
||||||
* This Central thread for Asynchronous Method of reinforcement learning
|
* This Central thread for Asynchronous Method of reinforcement learning
|
||||||
* enqueue the gradients coming from the different threads and update its
|
* enqueue the gradients coming from the different threads and update its
|
||||||
* model and target. Those neurals nets are then synced by the other threads.
|
* model and target. Those neurals nets are then synced by the other threads.
|
||||||
*
|
* <p>
|
||||||
* The benefits of this thread is that the updater is "shared" between all thread
|
* The benefits of this thread is that the updater is "shared" between all thread
|
||||||
* we have a single updater which is the single updater of the model contained here
|
* we have a single updater which is the single updater of the model contained here
|
||||||
*
|
* <p>
|
||||||
* This is similar to RMSProp with shared g and momentum
|
* This is similar to RMSProp with shared g and momentum
|
||||||
*
|
* <p>
|
||||||
* When Hogwild! is implemented, this could be replaced by a simple data
|
* When Hogwild! is implemented, this could be replaced by a simple data
|
||||||
* structure
|
* structure
|
||||||
*
|
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
|
public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
|
||||||
|
@ -56,7 +57,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
@Getter
|
@Getter
|
||||||
final private NN current;
|
final private NN current;
|
||||||
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
|
final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
|
||||||
final private AsyncConfiguration a3cc;
|
final private IAsyncLearningConfiguration configuration;
|
||||||
private final IAsyncLearning learning;
|
private final IAsyncLearning learning;
|
||||||
@Getter
|
@Getter
|
||||||
private AtomicInteger T = new AtomicInteger(0);
|
private AtomicInteger T = new AtomicInteger(0);
|
||||||
|
@ -65,20 +66,20 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
@Getter
|
@Getter
|
||||||
private boolean running = true;
|
private boolean running = true;
|
||||||
|
|
||||||
public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
|
public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
|
||||||
this.current = initial;
|
this.current = initial;
|
||||||
target = (NN) initial.clone();
|
target = (NN) initial.clone();
|
||||||
this.a3cc = a3cc;
|
this.configuration = configuration;
|
||||||
this.learning = learning;
|
this.learning = learning;
|
||||||
queue = new ConcurrentLinkedQueue<>();
|
queue = new ConcurrentLinkedQueue<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isTrainingComplete() {
|
public boolean isTrainingComplete() {
|
||||||
return T.get() >= a3cc.getMaxStep();
|
return T.get() >= configuration.getMaxStep();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void enqueue(Gradient[] gradient, Integer nstep) {
|
public void enqueue(Gradient[] gradient, Integer nstep) {
|
||||||
if(running && !isTrainingComplete()) {
|
if (running && !isTrainingComplete()) {
|
||||||
queue.add(new Pair<>(gradient, nstep));
|
queue.add(new Pair<>(gradient, nstep));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -94,9 +95,8 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
synchronized (this) {
|
synchronized (this) {
|
||||||
current.applyGradient(gradient, pair.getSecond());
|
current.applyGradient(gradient, pair.getSecond());
|
||||||
}
|
}
|
||||||
if (a3cc.getTargetDqnUpdateFreq() != -1
|
if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
|
||||||
&& T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond())
|
/ configuration.getLearnerUpdateFrequency()) {
|
||||||
/ a3cc.getTargetDqnUpdateFreq()) {
|
|
||||||
log.info("TARGET UPDATE at T = " + T.get());
|
log.info("TARGET UPDATE at T = " + T.get());
|
||||||
synchronized (this) {
|
synchronized (this) {
|
||||||
target.copy(current);
|
target.copy(current);
|
||||||
|
@ -111,7 +111,7 @@ public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncG
|
||||||
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
|
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
|
||||||
*/
|
*/
|
||||||
public void terminate() {
|
public void terminate() {
|
||||||
if(running) {
|
if (running) {
|
||||||
running = false;
|
running = false;
|
||||||
queue.clear();
|
queue.clear();
|
||||||
learning.terminate();
|
learning.terminate();
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -21,14 +22,17 @@ import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.*;
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
|
* The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads()
|
||||||
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
|
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
|
||||||
* (see setProgressEventInterval(int))
|
* (see setProgressEventInterval(int))
|
||||||
*
|
*
|
||||||
|
@ -37,8 +41,8 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
|
||||||
extends Learning<O, A, AS, NN>
|
extends Learning<O, A, AS, NN>
|
||||||
implements IAsyncLearning {
|
implements IAsyncLearning {
|
||||||
|
|
||||||
private Thread monitorThread = null;
|
private Thread monitorThread = null;
|
||||||
|
|
||||||
|
@ -56,9 +60,10 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the configuration
|
* Returns the configuration
|
||||||
* @return the configuration (see {@link AsyncConfiguration})
|
*
|
||||||
|
* @return the configuration (see {@link AsyncQLearningConfiguration})
|
||||||
*/
|
*/
|
||||||
public abstract AsyncConfiguration getConfiguration();
|
public abstract IAsyncLearningConfiguration getConfiguration();
|
||||||
|
|
||||||
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
protected abstract AsyncThread newThread(int i, int deviceAffinity);
|
||||||
|
|
||||||
|
@ -77,12 +82,13 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
/**
|
/**
|
||||||
* Number of milliseconds between calls to onTrainingProgress
|
* Number of milliseconds between calls to onTrainingProgress
|
||||||
*/
|
*/
|
||||||
@Getter @Setter
|
@Getter
|
||||||
|
@Setter
|
||||||
private int progressMonitorFrequency = 20000;
|
private int progressMonitorFrequency = 20000;
|
||||||
|
|
||||||
private void launchThreads() {
|
private void launchThreads() {
|
||||||
startGlobalThread();
|
startGlobalThread();
|
||||||
for (int i = 0; i < getConfiguration().getNumThread(); i++) {
|
for (int i = 0; i < getConfiguration().getNumThreads(); i++) {
|
||||||
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
Thread t = newThread(i, i % Nd4j.getAffinityManager().getNumberOfDevices());
|
||||||
t.start();
|
t.start();
|
||||||
}
|
}
|
||||||
|
@ -132,7 +138,7 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
monitorThread = Thread.currentThread();
|
monitorThread = Thread.currentThread();
|
||||||
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
while (canContinue && !isTrainingComplete() && getAsyncGlobal().isRunning()) {
|
||||||
canContinue = listeners.notifyTrainingProgress(this);
|
canContinue = listeners.notifyTrainingProgress(this);
|
||||||
if(!canContinue) {
|
if (!canContinue) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,11 +161,11 @@ public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpa
|
||||||
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
|
* Force the immediate termination of the learning. All learning threads, the AsyncGlobal thread and the monitor thread will be terminated.
|
||||||
*/
|
*/
|
||||||
public void terminate() {
|
public void terminate() {
|
||||||
if(canContinue) {
|
if (canContinue) {
|
||||||
canContinue = false;
|
canContinue = false;
|
||||||
|
|
||||||
Thread safeMonitorThread = monitorThread;
|
Thread safeMonitorThread = monitorThread;
|
||||||
if(safeMonitorThread != null) {
|
if (safeMonitorThread != null) {
|
||||||
safeMonitorThread.interrupt();
|
safeMonitorThread.interrupt();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -23,6 +24,7 @@ import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.*;
|
import org.deeplearning4j.rl4j.learning.*;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -155,7 +157,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
}
|
}
|
||||||
|
|
||||||
private void handleTraining(RunContext context) {
|
private void handleTraining(RunContext context) {
|
||||||
int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
|
int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
|
||||||
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
|
||||||
|
|
||||||
context.obs = subEpochReturn.getLastObs();
|
context.obs = subEpochReturn.getLastObs();
|
||||||
|
@ -197,7 +199,7 @@ public abstract class AsyncThread<O, A, AS extends ActionSpace<A>, NN extends Ne
|
||||||
|
|
||||||
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
protected abstract IAsyncGlobal<NN> getAsyncGlobal();
|
||||||
|
|
||||||
protected abstract AsyncConfiguration getConf();
|
protected abstract IAsyncLearningConfiguration getConf();
|
||||||
|
|
||||||
protected abstract IPolicy<O, A> getPolicy(NN net);
|
protected abstract IPolicy<O, A> getPolicy(NN net);
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
||||||
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
|
||||||
else {
|
else {
|
||||||
INDArray[] output = null;
|
INDArray[] output = null;
|
||||||
if (getConf().getTargetDqnUpdateFreq() == -1)
|
if (getConf().getLearnerUpdateFrequency() == -1)
|
||||||
output = current.outputAll(obs.getData());
|
output = current.outputAll(obs.getData());
|
||||||
else synchronized (getAsyncGlobal()) {
|
else synchronized (getAsyncGlobal()) {
|
||||||
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
|
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,11 +17,15 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.AllArgsConstructor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
|
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
import org.deeplearning4j.rl4j.policy.ACPolicy;
|
||||||
|
@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
|
||||||
* Training for A3C in the Discrete Domain
|
* Training for A3C in the Discrete Domain
|
||||||
*
|
* <p>
|
||||||
* All methods are fully implemented as described in the
|
* All methods are fully implemented as described in the
|
||||||
* https://arxiv.org/abs/1602.01783 paper.
|
* https://arxiv.org/abs/1602.01783 paper.
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
|
public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final public A3CConfiguration configuration;
|
final public A3CLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
final protected MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
final private IActorCritic iActorCritic;
|
final private IActorCritic iActorCritic;
|
||||||
|
@ -49,15 +53,15 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
@Getter
|
@Getter
|
||||||
final private ACPolicy<O> policy;
|
final private ACPolicy<O> policy;
|
||||||
|
|
||||||
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
|
public A3CDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
|
||||||
this.iActorCritic = iActorCritic;
|
this.iActorCritic = iActorCritic;
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
|
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
|
||||||
|
|
||||||
Integer seed = conf.getSeed();
|
Long seed = conf.getSeed();
|
||||||
Random rnd = Nd4j.getRandom();
|
Random rnd = Nd4j.getRandom();
|
||||||
if(seed != null) {
|
if (seed != null) {
|
||||||
rnd.setSeed(seed);
|
rnd.setSeed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +69,7 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
}
|
}
|
||||||
|
|
||||||
protected AsyncThread newThread(int i, int deviceNum) {
|
protected AsyncThread newThread(int i, int deviceNum) {
|
||||||
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, getConfiguration(), deviceNum, getListeners(), i);
|
return new A3CThreadDiscrete(mdp.newInstance(), asyncGlobal, this.getConfiguration(), deviceNum, getListeners(), i);
|
||||||
}
|
}
|
||||||
|
|
||||||
public IActorCritic getNeuralNet() {
|
public IActorCritic getNeuralNet() {
|
||||||
|
@ -76,9 +80,9 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class A3CConfiguration implements AsyncConfiguration {
|
public static class A3CConfiguration {
|
||||||
|
|
||||||
Integer seed;
|
int seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
int maxStep;
|
int maxStep;
|
||||||
int numThread;
|
int numThread;
|
||||||
|
@ -88,8 +92,20 @@ public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O,
|
||||||
double gamma;
|
double gamma;
|
||||||
double errorClamp;
|
double errorClamp;
|
||||||
|
|
||||||
public int getTargetDqnUpdateFreq() {
|
/**
|
||||||
return -1;
|
* Converts the deprecated A3CConfiguration to the new LearningConfiguration format
|
||||||
|
*/
|
||||||
|
public A3CLearningConfiguration toLearningConfiguration() {
|
||||||
|
return A3CLearningConfiguration.builder()
|
||||||
|
.seed(new Long(seed))
|
||||||
|
.maxEpochStep(maxEpochStep)
|
||||||
|
.maxStep(maxStep)
|
||||||
|
.numThreads(numThread)
|
||||||
|
.nStep(nstep)
|
||||||
|
.rewardFactor(rewardFactor)
|
||||||
|
.gamma(gamma)
|
||||||
|
.build();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -18,10 +19,12 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph;
|
||||||
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
|
@ -29,16 +32,15 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||||
*
|
* <p>
|
||||||
* Training for A3C in the Discrete Domain
|
* Training for A3C in the Discrete Domain
|
||||||
*
|
* <p>
|
||||||
* Specialized constructors for the Conv (pixels input) case
|
* Specialized constructors for the Conv (pixels input) case
|
||||||
* Specialized conf + provide additional type safety
|
* Specialized conf + provide additional type safety
|
||||||
*
|
* <p>
|
||||||
* It uses CompGraph because there is benefit to combine the
|
* It uses CompGraph because there is benefit to combine the
|
||||||
* first layers since they're essentially doing the same dimension
|
* first layers since they're essentially doing the same dimension
|
||||||
* reduction task
|
* reduction task
|
||||||
*
|
|
||||||
**/
|
**/
|
||||||
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
|
@ -46,12 +48,22 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, actorCritic, hpconf, conf);
|
this(mdp, actorCritic, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
|
||||||
|
super(mdp, IActorCritic, conf.toLearningConfiguration());
|
||||||
|
this.hpconf = hpconf;
|
||||||
|
setHistoryProcessor(hpconf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
super(mdp, IActorCritic, conf);
|
super(mdp, IActorCritic, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
|
@ -59,21 +71,35 @@ public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticNetworkConfiguration netConf,
|
||||||
|
HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,8 +17,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.*;
|
import org.deeplearning4j.rl4j.network.ac.*;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
|
||||||
|
@ -25,67 +28,81 @@ import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
|
||||||
*
|
* <p>
|
||||||
* Training for A3C in the Discrete Domain
|
* Training for A3C in the Discrete Domain
|
||||||
*
|
* <p>
|
||||||
* We use specifically the Separate version because
|
* We use specifically the Separate version because
|
||||||
* the model is too small to have enough benefit by sharing layers
|
* the model is too small to have enough benefit by sharing layers
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic IActorCritic, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, IActorCritic, conf);
|
this(mdp, IActorCritic, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CConfiguration conf) {
|
||||||
|
super(mdp, actorCritic, conf.toLearningConfiguration());
|
||||||
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
|
||||||
super(mdp, actorCritic, conf);
|
super(mdp, actorCritic, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
A3CConfiguration conf) {
|
A3CConfiguration conf) {
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactorySeparate factory,
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
A3CLearningConfiguration conf) {
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
|
||||||
IDataManager dataManager) {
|
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
|
|
||||||
}
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
|
||||||
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
|
||||||
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Deprecated
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
|
||||||
A3CConfiguration conf, IDataManager dataManager) {
|
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
|
||||||
dataManager);
|
|
||||||
}
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
|
||||||
A3CConfiguration conf) {
|
|
||||||
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
|
||||||
IDataManager dataManager) {
|
IDataManager dataManager) {
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
|
||||||
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
|
||||||
ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
|
|
||||||
this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
|
||||||
|
this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
A3CConfiguration conf, IDataManager dataManager) {
|
||||||
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
|
dataManager);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
A3CConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public A3CDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, ActorCriticFactoryCompGraph factory,
|
||||||
|
A3CLearningConfiguration conf) {
|
||||||
|
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
|
@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
|
||||||
|
|
||||||
import java.util.Stack;
|
import java.util.Stack;
|
||||||
|
|
||||||
|
@ -45,7 +46,7 @@ import java.util.Stack;
|
||||||
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
|
public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final protected A3CDiscrete.A3CConfiguration conf;
|
final protected A3CLearningConfiguration conf;
|
||||||
@Getter
|
@Getter
|
||||||
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
|
final protected IAsyncGlobal<IActorCritic> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -54,14 +55,14 @@ public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
||||||
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
|
public A3CThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IActorCritic> asyncGlobal,
|
||||||
A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
|
||||||
int threadNumber) {
|
int threadNumber) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = a3cc;
|
this.conf = a3cc;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
|
|
||||||
Integer seed = conf.getSeed();
|
Long seed = conf.getSeed();
|
||||||
rnd = Nd4j.getRandom();
|
rnd = Nd4j.getRandom();
|
||||||
if(seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
|
|
|
@ -1,49 +1,53 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.AllArgsConstructor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
|
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
import org.deeplearning4j.rl4j.policy.DQNPolicy;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
|
||||||
*/
|
*/
|
||||||
public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
|
extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final public AsyncNStepQLConfiguration configuration;
|
final public AsyncQLearningConfiguration configuration;
|
||||||
@Getter
|
@Getter
|
||||||
final private MDP<O, Integer, DiscreteSpace> mdp;
|
final private MDP<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
final private AsyncGlobal<IDQN> asyncGlobal;
|
final private AsyncGlobal<IDQN> asyncGlobal;
|
||||||
|
|
||||||
|
|
||||||
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
|
public AsyncNStepQLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
|
||||||
this.mdp = mdp;
|
this.mdp = mdp;
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
|
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
|
||||||
|
@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
return new DQNPolicy<O>(getNeuralNet());
|
return new DQNPolicy<O>(getNeuralNet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
|
public static class AsyncNStepQLConfiguration {
|
||||||
|
|
||||||
Integer seed;
|
Integer seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
|
@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
|
||||||
float minEpsilon;
|
float minEpsilon;
|
||||||
int epsilonNbStep;
|
int epsilonNbStep;
|
||||||
|
|
||||||
|
public AsyncQLearningConfiguration toLearningConfiguration() {
|
||||||
|
return AsyncQLearningConfiguration.builder()
|
||||||
|
.seed(new Long(seed))
|
||||||
|
.maxEpochStep(maxEpochStep)
|
||||||
|
.maxStep(maxStep)
|
||||||
|
.numThreads(numThread)
|
||||||
|
.nStep(nstep)
|
||||||
|
.targetDqnUpdateFreq(targetDqnUpdateFreq)
|
||||||
|
.updateStart(updateStart)
|
||||||
|
.rewardFactor(rewardFactor)
|
||||||
|
.gamma(gamma)
|
||||||
|
.errorClamp(errorClamp)
|
||||||
|
.minEpsilon(minEpsilon)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,24 +1,27 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, dqn, hpconf, conf);
|
this(mdp, dqn, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf);
|
super(mdp, dqn, conf);
|
||||||
this.hpconf = hpconf;
|
this.hpconf = hpconf;
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
|
@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv<O extends Encodable> extends AsyncN
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public AsyncNStepQLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,26 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* This program and the accompanying materials are made available under the
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
@ -32,35 +36,56 @@ public class AsyncNStepQLearningDiscreteDense<O extends Encodable> extends Async
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
super(mdp, dqn, conf);
|
super(mdp, dqn, conf.toLearningConfiguration());
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
AsyncNStepQLConfiguration conf) {
|
AsyncNStepQLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf.toLearningConfiguration());
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
|
AsyncQLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf);
|
super(mdp, dqn, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
AsyncNStepQLConfiguration conf) {
|
AsyncNStepQLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
AsyncQLearningConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public AsyncNStepQLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp,
|
||||||
|
DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
* License for the specific language governing permissions and limitations
|
* License for the specific language governing permissions and limitations
|
||||||
* under the License.
|
* under the License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
|
||||||
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
@ -42,7 +44,7 @@ import java.util.Stack;
|
||||||
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
|
public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
|
final protected AsyncQLearningConfiguration conf;
|
||||||
@Getter
|
@Getter
|
||||||
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
final protected IAsyncGlobal<IDQN> asyncGlobal;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
|
|
||||||
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
public AsyncNStepQLearningThreadDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IAsyncGlobal<IDQN> asyncGlobal,
|
||||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
|
AsyncQLearningConfiguration conf,
|
||||||
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
TrainingListenerList listeners, int threadNumber, int deviceNum) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
|
@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends Asyn
|
||||||
this.threadNumber = threadNumber;
|
this.threadNumber = threadNumber;
|
||||||
rnd = Nd4j.getRandom();
|
rnd = Nd4j.getRandom();
|
||||||
|
|
||||||
Integer seed = conf.getSeed();
|
Long seed = conf.getSeed();
|
||||||
if(seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed + threadNumber);
|
rnd.setSeed(seed + threadNumber);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of asynchronous threads to use to generate gradients
|
||||||
|
*/
|
||||||
|
private final int numThreads;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of steps to calculate gradients over
|
||||||
|
*/
|
||||||
|
private final int nStep;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The frequency of async training iterations to update the target network.
|
||||||
|
*
|
||||||
|
* If this is set to -1 then the target network is updated after every training iteration
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int learnerUpdateFrequency = -1;
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of asynchronous threads to use to generate experience data
|
||||||
|
*/
|
||||||
|
private final int numThreads;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of steps in each training interations
|
||||||
|
*/
|
||||||
|
private final int nStep;
|
||||||
|
|
||||||
|
public int getLearnerUpdateFrequency() {
|
||||||
|
return getTargetDqnUpdateFreq();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
|
public interface IAsyncLearningConfiguration extends ILearningConfiguration {
|
||||||
|
|
||||||
|
int getNumThreads();
|
||||||
|
|
||||||
|
int getNStep();
|
||||||
|
|
||||||
|
int getLearnerUpdateFrequency();
|
||||||
|
|
||||||
|
int getMaxStep();
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -14,36 +14,16 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
public interface ILearningConfiguration {
|
||||||
|
Long getSeed();
|
||||||
/**
|
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16.
|
|
||||||
*
|
|
||||||
* Interface configuration for all training method that inherit
|
|
||||||
* from AsyncLearning
|
|
||||||
*/
|
|
||||||
public interface AsyncConfiguration extends ILearning.LConfiguration {
|
|
||||||
|
|
||||||
Integer getSeed();
|
|
||||||
|
|
||||||
int getMaxEpochStep();
|
int getMaxEpochStep();
|
||||||
|
|
||||||
int getMaxStep();
|
int getMaxStep();
|
||||||
|
|
||||||
int getNumThread();
|
|
||||||
|
|
||||||
int getNstep();
|
|
||||||
|
|
||||||
int getTargetDqnUpdateFreq();
|
|
||||||
|
|
||||||
int getUpdateStart();
|
|
||||||
|
|
||||||
double getRewardFactor();
|
|
||||||
|
|
||||||
double getGamma();
|
double getGamma();
|
||||||
|
|
||||||
double getErrorClamp();
|
double getRewardFactor();
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,59 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class LearningConfiguration implements ILearningConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Seed value used for training
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private Long seed = System.currentTimeMillis();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum number of steps in each episode
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int maxEpochStep = 200;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum number of steps to train for
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int maxStep = 150000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gamma parameter used for discounted rewards
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double gamma = 0.99;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scaling parameter for rewards
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double rewardFactor = 1.0;
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,79 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.learning.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
public class QLearningConfiguration extends LearningConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The maximum size of the experience replay buffer
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int expRepMaxSize = 150000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The batch size of experience for each training iteration
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int batchSize = 32;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* How many steps between target network updates
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int targetDqnUpdateFreq = 100;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of steps to initially wait for until samplling batches from experience replay buffer
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int updateStart = 10;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will result in no clamping
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double errorClamp = 1.0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The minimum probability for random exploration action during episilon-greedy annealing
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double minEpsilon = 0.1f;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of steps to anneal epsilon to its minimum value.
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int epsilonNbStep = 10000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to use the double DQN algorithm
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private boolean doubleDQN = false;
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -63,7 +64,7 @@ public abstract class SyncLearning<O, A, AS extends ActionSpace<A>, NN extends N
|
||||||
/**
|
/**
|
||||||
* This method will train the model<p>
|
* This method will train the model<p>
|
||||||
* The training stop when:<br>
|
* The training stop when:<br>
|
||||||
* - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
|
* - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() LConfiguration.getMaxStep()})<br>
|
||||||
* OR<br>
|
* OR<br>
|
||||||
* - a listener explicitly stops it<br>
|
* - a listener explicitly stops it<br>
|
||||||
* <p>
|
* <p>
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
|
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
|
||||||
import lombok.*;
|
import lombok.AccessLevel;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
|
||||||
|
@ -59,15 +69,15 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
protected abstract LegacyMDPWrapper<O, A, AS> getLegacyMDPWrapper();
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf) {
|
public QLearning(QLearningConfiguration conf) {
|
||||||
this(conf, getSeededRandom(conf.getSeed()));
|
this(conf, getSeededRandom(conf.getSeed()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearning(QLConfiguration conf, Random random) {
|
public QLearning(QLearningConfiguration conf, Random random) {
|
||||||
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Random getSeededRandom(Integer seed) {
|
private static Random getSeededRandom(Long seed) {
|
||||||
Random rnd = Nd4j.getRandom();
|
Random rnd = Nd4j.getRandom();
|
||||||
if(seed != null) {
|
if(seed != null) {
|
||||||
rnd.setSeed(seed);
|
rnd.setSeed(seed);
|
||||||
|
@ -95,7 +105,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
return getQNetwork();
|
return getQNetwork();
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract QLConfiguration getConfiguration();
|
public abstract QLearningConfiguration getConfiguration();
|
||||||
|
|
||||||
protected abstract void preEpoch();
|
protected abstract void preEpoch();
|
||||||
|
|
||||||
|
@ -198,7 +208,7 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
double reward;
|
double reward;
|
||||||
int episodeLength;
|
int episodeLength;
|
||||||
List<Double> scores;
|
List<Double> scores;
|
||||||
float epsilon;
|
double epsilon;
|
||||||
double startQ;
|
double startQ;
|
||||||
double meanQ;
|
double meanQ;
|
||||||
}
|
}
|
||||||
|
@ -213,12 +223,14 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
|
@Deprecated
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
|
@JsonDeserialize(builder = QLConfiguration.QLConfigurationBuilder.class)
|
||||||
public static class QLConfiguration implements LConfiguration {
|
public static class QLConfiguration {
|
||||||
|
|
||||||
Integer seed;
|
Integer seed;
|
||||||
int maxEpochStep;
|
int maxEpochStep;
|
||||||
|
@ -237,7 +249,25 @@ public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A
|
||||||
@JsonPOJOBuilder(withPrefix = "")
|
@JsonPOJOBuilder(withPrefix = "")
|
||||||
public static final class QLConfigurationBuilder {
|
public static final class QLConfigurationBuilder {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public QLearningConfiguration toLearningConfiguration() {
|
||||||
|
|
||||||
|
return QLearningConfiguration.builder()
|
||||||
|
.seed(seed.longValue())
|
||||||
|
.maxEpochStep(maxEpochStep)
|
||||||
|
.maxStep(maxStep)
|
||||||
|
.expRepMaxSize(expRepMaxSize)
|
||||||
|
.batchSize(batchSize)
|
||||||
|
.targetDqnUpdateFreq(targetDqnUpdateFreq)
|
||||||
|
.updateStart(updateStart)
|
||||||
|
.rewardFactor(rewardFactor)
|
||||||
|
.gamma(gamma)
|
||||||
|
.errorClamp(errorClamp)
|
||||||
|
.minEpsilon(minEpsilon)
|
||||||
|
.epsilonNbStep(epsilonNbStep)
|
||||||
|
.doubleDQN(doubleDQN)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -22,6 +23,7 @@ import lombok.Setter;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||||
|
@ -45,16 +47,15 @@ import java.util.ArrayList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
|
||||||
*
|
* <p>
|
||||||
* DQN or Deep Q-Learning in the Discrete domain
|
* DQN or Deep Q-Learning in the Discrete domain
|
||||||
*
|
* <p>
|
||||||
* http://arxiv.org/abs/1312.5602
|
* http://arxiv.org/abs/1312.5602
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
|
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
final private QLConfiguration configuration;
|
final private QLearningConfiguration configuration;
|
||||||
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
|
private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
|
||||||
@Getter
|
@Getter
|
||||||
private DQNPolicy<O> policy;
|
private DQNPolicy<O> policy;
|
||||||
|
@ -78,16 +79,15 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
return mdp;
|
return mdp;
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) {
|
||||||
int epsilonNbStep) {
|
|
||||||
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLConfiguration conf,
|
public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf,
|
||||||
int epsilonNbStep, Random random) {
|
int epsilonNbStep, Random random) {
|
||||||
super(conf);
|
super(conf);
|
||||||
this.configuration = conf;
|
this.configuration = conf;
|
||||||
this.mdp = new LegacyMDPWrapper<O, Integer, DiscreteSpace>(mdp, null, this);
|
this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
|
||||||
qNetwork = dqn;
|
qNetwork = dqn;
|
||||||
targetQNetwork = dqn.clone();
|
targetQNetwork = dqn.clone();
|
||||||
policy = new DQNPolicy(getQNetwork());
|
policy = new DQNPolicy(getQNetwork());
|
||||||
|
@ -125,6 +125,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Single step of training
|
* Single step of training
|
||||||
|
*
|
||||||
* @param obs last obs
|
* @param obs last obs
|
||||||
* @return relevant info for next step
|
* @return relevant info for next step
|
||||||
*/
|
*/
|
||||||
|
@ -135,8 +136,8 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
boolean isHistoryProcessor = getHistoryProcessor() != null;
|
||||||
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
int skipFrame = isHistoryProcessor ? getHistoryProcessor().getConf().getSkipFrame() : 1;
|
||||||
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
int historyLength = isHistoryProcessor ? getHistoryProcessor().getConf().getHistoryLength() : 1;
|
||||||
int updateStart = getConfiguration().getUpdateStart()
|
int updateStart = this.getConfiguration().getUpdateStart()
|
||||||
+ ((getConfiguration().getBatchSize() + historyLength) * skipFrame);
|
+ ((this.getConfiguration().getBatchSize() + historyLength) * skipFrame);
|
||||||
|
|
||||||
Double maxQ = Double.NaN; //ignore if Nan for stats
|
Double maxQ = Double.NaN; //ignore if Nan for stats
|
||||||
|
|
||||||
|
@ -161,7 +162,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
if (!obs.isSkipped()) {
|
if (!obs.isSkipped()) {
|
||||||
|
|
||||||
// Add experience
|
// Add experience
|
||||||
if(pendingTransition != null) {
|
if (pendingTransition != null) {
|
||||||
pendingTransition.setNextObservation(obs);
|
pendingTransition.setNextObservation(obs);
|
||||||
getExpReplay().store(pendingTransition);
|
getExpReplay().store(pendingTransition);
|
||||||
}
|
}
|
||||||
|
@ -188,7 +189,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void finishEpoch(Observation observation) {
|
protected void finishEpoch(Observation observation) {
|
||||||
if(pendingTransition != null) {
|
if (pendingTransition != null) {
|
||||||
pendingTransition.setNextObservation(observation);
|
pendingTransition.setNextObservation(observation);
|
||||||
getExpReplay().store(pendingTransition);
|
getExpReplay().store(pendingTransition);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -17,7 +18,9 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
@ -36,33 +39,55 @@ public class QLearningDiscreteConv<O extends Encodable> extends QLearningDiscret
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf, IDataManager dataManager) {
|
QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, dqn, hpconf, conf);
|
this(mdp, dqn, hpconf, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
QLConfiguration conf) {
|
QLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
|
setHistoryProcessor(hpconf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
|
||||||
|
QLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
|
||||||
setHistoryProcessor(hpconf);
|
setHistoryProcessor(hpconf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
|
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdConv.Configuration netConf,
|
||||||
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteConv(MDP<O, Integer, DiscreteSpace> mdp, NetworkConfiguration netConf,
|
||||||
|
HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,8 +17,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
|
@ -38,7 +41,13 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
|
||||||
this(mdp, dqn, conf);
|
this(mdp, dqn, conf);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf) {
|
||||||
|
super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearningConfiguration conf) {
|
||||||
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
super(mdp, dqn, conf, conf.getEpsilonNbStep());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,18 +57,33 @@ public class QLearningDiscreteDense<O extends Encodable> extends QLearningDiscre
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
|
||||||
dataManager);
|
dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
QLearning.QLConfiguration conf) {
|
QLearning.QLConfiguration conf) {
|
||||||
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactory factory,
|
||||||
|
QLearningConfiguration conf) {
|
||||||
|
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
QLearning.QLConfiguration conf, IDataManager dataManager) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
|
|
||||||
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Deprecated
|
||||||
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNFactoryStdDense.Configuration netConf,
|
||||||
QLearning.QLConfiguration conf) {
|
QLearning.QLConfiguration conf) {
|
||||||
|
this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
|
||||||
|
}
|
||||||
|
|
||||||
|
public QLearningDiscreteDense(MDP<O, Integer, DiscreteSpace> mdp, DQNDenseNetworkConfiguration netConf,
|
||||||
|
QLearningConfiguration conf) {
|
||||||
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
this(mdp, new DQNFactoryStdDense(netConf), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -36,7 +37,7 @@ import java.util.Collection;
|
||||||
*
|
*
|
||||||
* Standard implementation of ActorCriticCompGraph
|
* Standard implementation of ActorCriticCompGraph
|
||||||
*/
|
*/
|
||||||
public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IActorCritic<NN> {
|
public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph> {
|
||||||
|
|
||||||
final protected ComputationGraph cg;
|
final protected ComputationGraph cg;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -73,13 +74,13 @@ public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IA
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public NN clone() {
|
public ActorCriticCompGraph clone() {
|
||||||
NN nn = (NN)new ActorCriticCompGraph(cg.clone());
|
ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone());
|
||||||
nn.cg.setListeners(cg.getListeners());
|
nn.cg.setListeners(cg.getListeners());
|
||||||
return nn;
|
return nn;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void copy(NN from) {
|
public void copy(ActorCriticCompGraph from) {
|
||||||
cg.setParams(from.cg.params());
|
cg.setParams(from.cg.params());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
||||||
*
|
*
|
||||||
|
@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
@Value
|
@Value
|
||||||
public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph {
|
public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph {
|
||||||
|
|
||||||
|
ActorCriticNetworkConfiguration conf;
|
||||||
Configuration conf;
|
|
||||||
|
|
||||||
public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) {
|
public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) {
|
||||||
|
|
||||||
|
@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom
|
||||||
return new ActorCriticCompGraph(model);
|
return new ActorCriticCompGraph(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
@Value
|
@Value
|
||||||
|
@Deprecated
|
||||||
public static class Configuration {
|
public static class Configuration {
|
||||||
|
|
||||||
double l2;
|
double l2;
|
||||||
IUpdater updater;
|
IUpdater updater;
|
||||||
TrainingListener[] listeners;
|
TrainingListener[] listeners;
|
||||||
boolean useLSTM;
|
boolean useLSTM;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the deprecated Configuration to the new NetworkConfiguration format
|
||||||
|
*/
|
||||||
|
public ActorCriticNetworkConfiguration toNetworkConfiguration() {
|
||||||
|
ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder()
|
||||||
|
.l2(l2)
|
||||||
|
.updater(updater)
|
||||||
|
.useLSTM(useLSTM);
|
||||||
|
|
||||||
|
if (listeners != null) {
|
||||||
|
builder.listeners(Arrays.asList(listeners));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,8 +17,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.network.ac;
|
package org.deeplearning4j.rl4j.network.ac;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
|
@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
@Value
|
@Value
|
||||||
public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph {
|
public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph {
|
||||||
|
|
||||||
Configuration conf;
|
ActorCriticDenseNetworkConfiguration conf;
|
||||||
|
|
||||||
public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) {
|
public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) {
|
||||||
int nIn = 1;
|
int nIn = 1;
|
||||||
|
@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
|
||||||
"input");
|
"input");
|
||||||
|
|
||||||
|
|
||||||
for (int i = 1; i < conf.getNumLayer(); i++) {
|
for (int i = 1; i < conf.getNumLayers(); i++) {
|
||||||
confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build(), (i - 1) + "");
|
.activation(Activation.RELU).build(), (i - 1) + "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (conf.isUseLSTM()) {
|
if (conf.isUseLSTM()) {
|
||||||
confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH)
|
confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH)
|
||||||
.nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + "");
|
.nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + "");
|
||||||
|
|
||||||
confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
||||||
.nOut(1).build(), getConf().getNumLayer() + "");
|
.nOut(1).build(), getConf().getNumLayers() + "");
|
||||||
|
|
||||||
confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
|
confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
|
||||||
.nOut(numOutputs).build(), getConf().getNumLayer() + "");
|
.nOut(numOutputs).build(), getConf().getNumLayers() + "");
|
||||||
} else {
|
} else {
|
||||||
confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
||||||
.nOut(1).build(), (getConf().getNumLayer() - 1) + "");
|
.nOut(1).build(), (getConf().getNumLayers() - 1) + "");
|
||||||
|
|
||||||
confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
|
confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
|
||||||
.nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + "");
|
.nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + "");
|
||||||
}
|
}
|
||||||
|
|
||||||
confB.setOutputs("value", "softmax");
|
confB.setOutputs("value", "softmax");
|
||||||
|
@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
|
||||||
return new ActorCriticCompGraph(model);
|
return new ActorCriticCompGraph(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Builder
|
|
||||||
@Value
|
|
||||||
public static class Configuration {
|
|
||||||
|
|
||||||
int numLayer;
|
|
||||||
int numHiddenNodes;
|
|
||||||
double l2;
|
|
||||||
IUpdater updater;
|
|
||||||
TrainingListener[] listeners;
|
|
||||||
boolean useLSTM;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
|
||||||
*
|
|
||||||
*
|
|
||||||
*/
|
*/
|
||||||
@Value
|
@Value
|
||||||
public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
|
public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
|
||||||
|
|
||||||
Configuration conf;
|
ActorCriticDenseNetworkConfiguration conf;
|
||||||
|
|
||||||
public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) {
|
public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) {
|
||||||
int nIn = 1;
|
int nIn = 1;
|
||||||
|
@ -53,27 +57,27 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
|
||||||
nIn *= i;
|
nIn *= i;
|
||||||
}
|
}
|
||||||
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.l2(conf.getL2())
|
.l2(conf.getL2())
|
||||||
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
|
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build());
|
.activation(Activation.RELU).build());
|
||||||
|
|
||||||
|
|
||||||
for (int i = 1; i < conf.getNumLayer(); i++) {
|
for (int i = 1; i < conf.getNumLayers(); i++) {
|
||||||
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build());
|
.activation(Activation.RELU).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (conf.isUseLSTM()) {
|
if (conf.isUseLSTM()) {
|
||||||
confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
|
confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
|
||||||
|
|
||||||
confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
||||||
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
|
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
|
||||||
} else {
|
} else {
|
||||||
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
||||||
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
|
.nIn(conf.getNumHiddenNodes()).nOut(1).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
|
confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
|
||||||
|
@ -87,28 +91,28 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
|
||||||
}
|
}
|
||||||
|
|
||||||
NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
//.regularization(true)
|
//.regularization(true)
|
||||||
//.l2(conf.getL2())
|
//.l2(conf.getL2())
|
||||||
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
|
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build());
|
.activation(Activation.RELU).build());
|
||||||
|
|
||||||
|
|
||||||
for (int i = 1; i < conf.getNumLayer(); i++) {
|
for (int i = 1; i < conf.getNumLayers(); i++) {
|
||||||
confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build());
|
.activation(Activation.RELU).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (conf.isUseLSTM()) {
|
if (conf.isUseLSTM()) {
|
||||||
confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
|
confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
|
||||||
|
|
||||||
confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
|
confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
|
||||||
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
|
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
|
||||||
} else {
|
} else {
|
||||||
confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss())
|
confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss())
|
||||||
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
|
.activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
|
confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
|
||||||
|
@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Value
|
@Value
|
||||||
@Builder
|
@Builder
|
||||||
|
@Deprecated
|
||||||
public static class Configuration {
|
public static class Configuration {
|
||||||
|
|
||||||
int numLayer;
|
int numLayer;
|
||||||
|
@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
|
||||||
IUpdater updater;
|
IUpdater updater;
|
||||||
TrainingListener[] listeners;
|
TrainingListener[] listeners;
|
||||||
boolean useLSTM;
|
boolean useLSTM;
|
||||||
|
|
||||||
|
public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() {
|
||||||
|
ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder()
|
||||||
|
.numHiddenNodes(numHiddenNodes)
|
||||||
|
.numLayers(numLayer)
|
||||||
|
.l2(l2)
|
||||||
|
.updater(updater)
|
||||||
|
.useLSTM(useLSTM);
|
||||||
|
|
||||||
|
if (listeners != null) {
|
||||||
|
builder.listeners(Arrays.asList(listeners));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.network.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of layers in the dense network
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int numLayers = 3;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of hidden neurons in each layer
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int numHiddenNodes = 100;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.network.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
public class ActorCriticNetworkConfiguration extends NetworkConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether or not to add an LSTM layer to the network.
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private boolean useLSTM = false;
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.network.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
public class DQNDenseNetworkConfiguration extends NetworkConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of layers in the dense network
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int numLayers = 3;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of hidden neurons in each layer
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private int numHiddenNodes = 100;
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.network.configuration;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.Singular;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@SuperBuilder
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class NetworkConfiguration {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The learning rate of the network
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double learningRate = 0.01;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* L2 regularization on the network
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private double l2 = 0.0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The network's gradient update algorithm
|
||||||
|
*/
|
||||||
|
private IUpdater updater;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Training listeners attached to the network
|
||||||
|
*/
|
||||||
|
@Singular
|
||||||
|
private List<TrainingListener> listeners;
|
||||||
|
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
|
||||||
*/
|
*/
|
||||||
|
@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
public class DQNFactoryStdConv implements DQNFactory {
|
public class DQNFactoryStdConv implements DQNFactory {
|
||||||
|
|
||||||
|
|
||||||
Configuration conf;
|
NetworkConfiguration conf;
|
||||||
|
|
||||||
public DQN buildDQN(int shapeInputs[], int numOutputs) {
|
public DQN buildDQN(int shapeInputs[], int numOutputs) {
|
||||||
|
|
||||||
|
@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory {
|
||||||
return new DQN(model);
|
return new DQN(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Builder
|
@Builder
|
||||||
@Value
|
@Value
|
||||||
|
@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory {
|
||||||
double l2;
|
double l2;
|
||||||
IUpdater updater;
|
IUpdater updater;
|
||||||
TrainingListener[] listeners;
|
TrainingListener[] listeners;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the deprecated Configuration to the new NetworkConfiguration format
|
||||||
|
*/
|
||||||
|
public NetworkConfiguration toNetworkConfiguration() {
|
||||||
|
NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder()
|
||||||
|
.learningRate(learningRate)
|
||||||
|
.l2(l2)
|
||||||
|
.updater(updater);
|
||||||
|
|
||||||
|
if (listeners != null) {
|
||||||
|
builder.listeners(Arrays.asList(listeners));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder;
|
||||||
import org.deeplearning4j.rl4j.util.Constants;
|
import org.deeplearning4j.rl4j.util.Constants;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Adam;
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
|
||||||
*/
|
*/
|
||||||
|
@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
@Value
|
@Value
|
||||||
public class DQNFactoryStdDense implements DQNFactory {
|
public class DQNFactoryStdDense implements DQNFactory {
|
||||||
|
|
||||||
|
DQNDenseNetworkConfiguration conf;
|
||||||
Configuration conf;
|
|
||||||
|
|
||||||
public DQN buildDQN(int[] numInputs, int numOutputs) {
|
public DQN buildDQN(int[] numInputs, int numOutputs) {
|
||||||
int nIn = 1;
|
int nIn = 1;
|
||||||
|
|
||||||
for (int i : numInputs) {
|
for (int i : numInputs) {
|
||||||
nIn *= i;
|
nIn *= i;
|
||||||
}
|
}
|
||||||
|
|
||||||
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
//.updater(Updater.NESTEROVS).momentum(0.9)
|
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
||||||
//.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay())
|
.weightInit(WeightInit.XAVIER)
|
||||||
.updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
|
.l2(conf.getL2())
|
||||||
.weightInit(WeightInit.XAVIER)
|
.list()
|
||||||
.l2(conf.getL2())
|
.layer(0,
|
||||||
.list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
|
new DenseLayer.Builder()
|
||||||
.activation(Activation.RELU).build());
|
.nIn(nIn)
|
||||||
|
.nOut(conf.getNumHiddenNodes())
|
||||||
|
.activation(Activation.RELU).build()
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
for (int i = 1; i < conf.getNumLayer(); i++) {
|
for (int i = 1; i < conf.getNumLayers(); i++) {
|
||||||
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
|
||||||
.activation(Activation.RELU).build());
|
.activation(Activation.RELU).build());
|
||||||
}
|
}
|
||||||
|
|
||||||
confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
|
confB.layer(conf.getNumLayers(),
|
||||||
.nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
|
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
|
||||||
|
.activation(Activation.IDENTITY)
|
||||||
|
.nIn(conf.getNumHiddenNodes())
|
||||||
|
.nOut(numOutputs)
|
||||||
|
.build()
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
MultiLayerConfiguration mlnconf = confB.build();
|
MultiLayerConfiguration mlnconf = confB.build();
|
||||||
|
@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory {
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Value
|
@Value
|
||||||
@Builder
|
@Builder
|
||||||
|
@Deprecated
|
||||||
public static class Configuration {
|
public static class Configuration {
|
||||||
|
|
||||||
int numLayer;
|
int numLayer;
|
||||||
|
@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory {
|
||||||
double l2;
|
double l2;
|
||||||
IUpdater updater;
|
IUpdater updater;
|
||||||
TrainingListener[] listeners;
|
TrainingListener[] listeners;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts the deprecated Configuration to the new NetworkConfiguration format
|
||||||
|
*/
|
||||||
|
public DQNDenseNetworkConfiguration toNetworkConfiguration() {
|
||||||
|
DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder()
|
||||||
|
.numHiddenNodes(numHiddenNodes)
|
||||||
|
.numLayers(numLayer)
|
||||||
|
.l2(l2)
|
||||||
|
.updater(updater);
|
||||||
|
|
||||||
|
if (listeners != null) {
|
||||||
|
builder.listeners(Arrays.asList(listeners));
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -46,7 +47,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||||
final private int updateStart;
|
final private int updateStart;
|
||||||
final private int epsilonNbStep;
|
final private int epsilonNbStep;
|
||||||
final private Random rnd;
|
final private Random rnd;
|
||||||
final private float minEpsilon;
|
final private double minEpsilon;
|
||||||
final private IEpochTrainer learning;
|
final private IEpochTrainer learning;
|
||||||
|
|
||||||
public NeuralNet getNeuralNet() {
|
public NeuralNet getNeuralNet() {
|
||||||
|
@ -55,10 +56,10 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||||
|
|
||||||
public A nextAction(INDArray input) {
|
public A nextAction(INDArray input) {
|
||||||
|
|
||||||
float ep = getEpsilon();
|
double ep = getEpsilon();
|
||||||
if (learning.getStepCounter() % 500 == 1)
|
if (learning.getStepCounter() % 500 == 1)
|
||||||
log.info("EP: " + ep + " " + learning.getStepCounter());
|
log.info("EP: " + ep + " " + learning.getStepCounter());
|
||||||
if (rnd.nextFloat() > ep)
|
if (rnd.nextDouble() > ep)
|
||||||
return policy.nextAction(input);
|
return policy.nextAction(input);
|
||||||
else
|
else
|
||||||
return mdp.getActionSpace().randomAction();
|
return mdp.getActionSpace().randomAction();
|
||||||
|
@ -68,7 +69,7 @@ public class EpsGreedy<O, A, AS extends ActionSpace<A>> extends Policy<O, A> {
|
||||||
return this.nextAction(observation.getData());
|
return this.nextAction(observation.getData());
|
||||||
}
|
}
|
||||||
|
|
||||||
public float getEpsilon() {
|
public double getEpsilon() {
|
||||||
return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep));
|
return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -22,17 +23,30 @@ import lombok.Builder;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.DQN;
|
import org.deeplearning4j.rl4j.network.dqn.DQN;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.BufferedOutputStream;
|
||||||
import java.nio.file.*;
|
import java.io.BufferedReader;
|
||||||
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.InputStreamReader;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
import java.nio.file.Paths;
|
||||||
|
import java.nio.file.StandardCopyOption;
|
||||||
|
import java.nio.file.StandardOpenOption;
|
||||||
import java.util.zip.ZipEntry;
|
import java.util.zip.ZipEntry;
|
||||||
import java.util.zip.ZipFile;
|
import java.util.zip.ZipFile;
|
||||||
import java.util.zip.ZipOutputStream;
|
import java.util.zip.ZipOutputStream;
|
||||||
|
@ -304,7 +318,7 @@ public class DataManager implements IDataManager {
|
||||||
public static class Info {
|
public static class Info {
|
||||||
String trainingName;
|
String trainingName;
|
||||||
String mdpName;
|
String mdpName;
|
||||||
ILearning.LConfiguration conf;
|
ILearningConfiguration conf;
|
||||||
int stepCounter;
|
int stepCounter;
|
||||||
long millisTime;
|
long millisTime;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,14 +17,11 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning;
|
package org.deeplearning4j.rl4j.learning;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue;
|
||||||
public class HistoryProcessorTest {
|
public class HistoryProcessorTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testHistoryProcessor() throws Exception {
|
public void testHistoryProcessor() {
|
||||||
HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder()
|
HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder()
|
||||||
.croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build();
|
.croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build();
|
||||||
IHistoryProcessor hp = new HistoryProcessor(conf);
|
IHistoryProcessor hp = new HistoryProcessor(conf);
|
||||||
|
@ -43,8 +41,6 @@ public class HistoryProcessorTest {
|
||||||
hp.add(a);
|
hp.add(a);
|
||||||
INDArray[] h = hp.getHistory();
|
INDArray[] h = hp.getHistory();
|
||||||
assertEquals(4, h.length);
|
assertEquals(4, h.length);
|
||||||
// System.out.println(Arrays.toString(a.shape()));
|
|
||||||
// System.out.println(Arrays.toString(h[0].shape()));
|
|
||||||
assertEquals( 1, h[0].shape()[0]);
|
assertEquals( 1, h[0].shape()[0]);
|
||||||
assertEquals(a.shape()[0], h[0].shape()[1]);
|
assertEquals(a.shape()[0], h[0].shape()[1]);
|
||||||
assertEquals(a.shape()[1], h[0].shape()[2]);
|
assertEquals(a.shape()[1], h[0].shape()[2]);
|
||||||
|
|
|
@ -1,9 +1,32 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockEncodable;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockNeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockPolicy;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
@ -68,7 +91,7 @@ public class AsyncLearningTest {
|
||||||
|
|
||||||
|
|
||||||
public static class TestContext {
|
public static class TestContext {
|
||||||
MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0);
|
MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0);
|
||||||
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
|
||||||
public final MockPolicy policy = new MockPolicy();
|
public final MockPolicy policy = new MockPolicy();
|
||||||
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
|
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
|
||||||
|
@ -82,11 +105,11 @@ public class AsyncLearningTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
|
public static class TestAsyncLearning extends AsyncLearning<MockEncodable, Integer, DiscreteSpace, MockNeuralNet> {
|
||||||
private final AsyncConfiguration conf;
|
private final IAsyncLearningConfiguration conf;
|
||||||
private final IAsyncGlobal asyncGlobal;
|
private final IAsyncGlobal asyncGlobal;
|
||||||
private final IPolicy<MockEncodable, Integer> policy;
|
private final IPolicy<MockEncodable, Integer> policy;
|
||||||
|
|
||||||
public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy<MockEncodable, Integer> policy) {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
this.policy = policy;
|
this.policy = policy;
|
||||||
|
@ -98,7 +121,7 @@ public class AsyncLearningTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AsyncConfiguration getConfiguration() {
|
public IAsyncLearningConfiguration getConfiguration() {
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,25 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
|
@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||||
TrainingListenerList listeners = new TrainingListenerList();
|
TrainingListenerList listeners = new TrainingListenerList();
|
||||||
MockPolicy policyMock = new MockPolicy();
|
MockPolicy policyMock = new MockPolicy();
|
||||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
|
MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
|
||||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||||
|
|
||||||
|
@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected AsyncConfiguration getConf() {
|
protected IAsyncLearningConfiguration getConf() {
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.Policy;
|
import org.deeplearning4j.rl4j.policy.Policy;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockEncodable;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockNeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockTrainingListener;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
||||||
|
@ -16,7 +24,6 @@ import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNull;
|
|
||||||
|
|
||||||
public class AsyncThreadTest {
|
public class AsyncThreadTest {
|
||||||
|
|
||||||
|
@ -126,7 +133,7 @@ public class AsyncThreadTest {
|
||||||
public final MockNeuralNet neuralNet = new MockNeuralNet();
|
public final MockNeuralNet neuralNet = new MockNeuralNet();
|
||||||
public final MockObservationSpace observationSpace = new MockObservationSpace();
|
public final MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
public final MockMDP mdp = new MockMDP(observationSpace);
|
public final MockMDP mdp = new MockMDP(observationSpace);
|
||||||
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
|
public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
|
||||||
public final TrainingListenerList listeners = new TrainingListenerList();
|
public final TrainingListenerList listeners = new TrainingListenerList();
|
||||||
public final MockTrainingListener listener = new MockTrainingListener();
|
public final MockTrainingListener listener = new MockTrainingListener();
|
||||||
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
|
@ -149,11 +156,11 @@ public class AsyncThreadTest {
|
||||||
|
|
||||||
private final MockAsyncGlobal asyncGlobal;
|
private final MockAsyncGlobal asyncGlobal;
|
||||||
private final MockNeuralNet neuralNet;
|
private final MockNeuralNet neuralNet;
|
||||||
private final AsyncConfiguration conf;
|
private final IAsyncLearningConfiguration conf;
|
||||||
|
|
||||||
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
|
private final List<TrainSubEpochParams> trainSubEpochParams = new ArrayList<TrainSubEpochParams>();
|
||||||
|
|
||||||
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
|
public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
|
||||||
super(asyncGlobal, mdp, listeners, threadNumber, 0);
|
super(asyncGlobal, mdp, listeners, threadNumber, 0);
|
||||||
|
|
||||||
this.asyncGlobal = asyncGlobal;
|
this.asyncGlobal = asyncGlobal;
|
||||||
|
@ -184,7 +191,7 @@ public class AsyncThreadTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected AsyncConfiguration getConf() {
|
protected IAsyncLearningConfiguration getConf() {
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
|
import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
|
@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest {
|
||||||
double gamma = 0.9;
|
double gamma = 0.9;
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||||
A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0);
|
A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
|
||||||
MockActorCritic actorCriticMock = new MockActorCritic();
|
MockActorCritic actorCriticMock = new MockActorCritic();
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
||||||
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
MockAsyncGlobal<IActorCritic> asyncGlobalMock = new MockAsyncGlobal<IActorCritic>(actorCriticMock);
|
||||||
|
@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest {
|
||||||
Nd4j.zeros(5)
|
Nd4j.zeros(5)
|
||||||
};
|
};
|
||||||
output[0].putScalar(i, outputs[i]);
|
output[0].putScalar(i, outputs[i]);
|
||||||
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
|
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
||||||
}
|
}
|
||||||
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.calcGradient(actorCriticMock, minitransList);
|
sut.calcGradient(actorCriticMock, minitransList);
|
||||||
|
|
|
@ -1,7 +1,24 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2020 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.*;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest {
|
||||||
double gamma = 0.9;
|
double gamma = 0.9;
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
MockMDP mdpMock = new MockMDP(observationSpace);
|
MockMDP mdpMock = new MockMDP(observationSpace);
|
||||||
AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0);
|
AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
|
||||||
MockDQN dqnMock = new MockDQN();
|
MockDQN dqnMock = new MockDQN();
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
|
||||||
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
|
||||||
|
@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest {
|
||||||
Nd4j.zeros(5)
|
Nd4j.zeros(5)
|
||||||
};
|
};
|
||||||
output[0].putScalar(i, outputs[i]);
|
output[0].putScalar(i, outputs[i]);
|
||||||
minitransList.push(new MiniTrans<Integer>(obs, i, output, rewards[i]));
|
minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
|
||||||
}
|
}
|
||||||
minitransList.push(new MiniTrans<Integer>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
sut.calcGradient(dqnMock, minitransList);
|
sut.calcGradient(dqnMock, minitransList);
|
||||||
|
|
|
@ -1,6 +1,26 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync;
|
package org.deeplearning4j.rl4j.learning.sync;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -17,7 +37,7 @@ public class SyncLearningTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_training_expect_listenersToBeCalled() {
|
public void when_training_expect_listenersToBeCalled() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
|
||||||
MockTrainingListener listener = new MockTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
|
@ -34,7 +54,7 @@ public class SyncLearningTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
|
||||||
MockTrainingListener listener = new MockTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
|
@ -52,7 +72,7 @@ public class SyncLearningTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
|
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
|
||||||
MockTrainingListener listener = new MockTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
|
@ -70,7 +90,7 @@ public class SyncLearningTest {
|
||||||
@Test
|
@Test
|
||||||
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
|
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
|
||||||
// Arrange
|
// Arrange
|
||||||
QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
|
LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
|
||||||
MockTrainingListener listener = new MockTrainingListener();
|
MockTrainingListener listener = new MockTrainingListener();
|
||||||
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
MockSyncLearning sut = new MockSyncLearning(lconfig);
|
||||||
sut.addListener(listener);
|
sut.addListener(listener);
|
||||||
|
@ -87,12 +107,12 @@ public class SyncLearningTest {
|
||||||
|
|
||||||
public static class MockSyncLearning extends SyncLearning {
|
public static class MockSyncLearning extends SyncLearning {
|
||||||
|
|
||||||
private final LConfiguration conf;
|
private final ILearningConfiguration conf;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private int currentEpochStep = 0;
|
private int currentEpochStep = 0;
|
||||||
|
|
||||||
public MockSyncLearning(LConfiguration conf) {
|
public MockSyncLearning(ILearningConfiguration conf) {
|
||||||
this.conf = conf;
|
this.conf = conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +139,7 @@ public class SyncLearningTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LConfiguration getConfiguration() {
|
public ILearningConfiguration getConfiguration() {
|
||||||
return conf;
|
return conf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -17,36 +18,24 @@
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|
||||||
|
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.ExpectedException;
|
import org.junit.rules.ExpectedException;
|
||||||
|
|
||||||
public class QLConfigurationTest {
|
public class QLearningConfigurationTest {
|
||||||
@Rule
|
@Rule
|
||||||
public ExpectedException thrown = ExpectedException.none();
|
public ExpectedException thrown = ExpectedException.none();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void serialize() throws Exception {
|
public void serialize() throws Exception {
|
||||||
ObjectMapper mapper = new ObjectMapper();
|
ObjectMapper mapper = new ObjectMapper();
|
||||||
QLearning.QLConfiguration qlConfiguration =
|
|
||||||
new QLearning.QLConfiguration(
|
QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder()
|
||||||
123, //Random seed
|
.build();
|
||||||
200, //Max step By epoch
|
|
||||||
8000, //Max step
|
|
||||||
150000, //Max size of experience replay
|
|
||||||
32, //size of batches
|
|
||||||
500, //target update (hard)
|
|
||||||
10, //num step noop warmup
|
|
||||||
0.01, //reward scaling
|
|
||||||
0.99, //gamma
|
|
||||||
1.0, //td error clipping
|
|
||||||
0.1f, //min epsilon
|
|
||||||
10000, //num step for eps greedy anneal
|
|
||||||
true //double DQN
|
|
||||||
);
|
|
||||||
|
|
||||||
// Should not throw..
|
// Should not throw..
|
||||||
String json = mapper.writeValueAsString(qlConfiguration);
|
String json = mapper.writeValueAsString(qLearningConfiguration);
|
||||||
QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class);
|
QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,6 +1,24 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||||
|
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||||
|
@ -27,7 +45,7 @@ public class QLearningDiscreteTest {
|
||||||
// Arrange
|
// Arrange
|
||||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||||
MockDQN dqn = new MockDQN();
|
MockDQN dqn = new MockDQN();
|
||||||
MockRandom random = new MockRandom(new double[] {
|
MockRandom random = new MockRandom(new double[]{
|
||||||
0.7309677600860596,
|
0.7309677600860596,
|
||||||
0.8314409852027893,
|
0.8314409852027893,
|
||||||
0.2405363917350769,
|
0.2405363917350769,
|
||||||
|
@ -36,14 +54,26 @@ public class QLearningDiscreteTest {
|
||||||
0.3090505599975586,
|
0.3090505599975586,
|
||||||
0.5504369735717773,
|
0.5504369735717773,
|
||||||
0.11700659990310669
|
0.11700659990310669
|
||||||
},
|
},
|
||||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
|
||||||
MockMDP mdp = new MockMDP(observationSpace, random);
|
MockMDP mdp = new MockMDP(observationSpace, random);
|
||||||
|
|
||||||
int initStepCount = 8;
|
int initStepCount = 8;
|
||||||
|
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
|
QLearningConfiguration conf = QLearningConfiguration.builder()
|
||||||
initStepCount, 1.0, 0, 0, 0, 0, true);
|
.seed(0L)
|
||||||
|
.maxEpochStep(24)
|
||||||
|
.maxStep(0)
|
||||||
|
.expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
|
||||||
|
.updateStart(initStepCount)
|
||||||
|
.rewardFactor(1.0)
|
||||||
|
.gamma(0)
|
||||||
|
.errorClamp(0)
|
||||||
|
.minEpsilon(0)
|
||||||
|
.epsilonNbStep(0)
|
||||||
|
.doubleDQN(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
MockDataManager dataManager = new MockDataManager(false);
|
MockDataManager dataManager = new MockDataManager(false);
|
||||||
MockExpReplay expReplay = new MockExpReplay();
|
MockExpReplay expReplay = new MockExpReplay();
|
||||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
|
||||||
|
@ -58,9 +88,9 @@ public class QLearningDiscreteTest {
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
// HistoryProcessor calls
|
// HistoryProcessor calls
|
||||||
double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
|
double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
|
||||||
assertEquals(expectedRecords.length, hp.recordCalls.size());
|
assertEquals(expectedRecords.length, hp.recordCalls.size());
|
||||||
for(int i = 0; i < expectedRecords.length; ++i) {
|
for (int i = 0; i < expectedRecords.length; ++i) {
|
||||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,59 +102,59 @@ public class QLearningDiscreteTest {
|
||||||
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
|
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
|
||||||
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
||||||
assertEquals(14, dqn.outputParams.size());
|
assertEquals(14, dqn.outputParams.size());
|
||||||
double[][] expectedDQNOutput = new double[][] {
|
double[][] expectedDQNOutput = new double[][]{
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
|
||||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
|
||||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
|
||||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
|
||||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
|
||||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
|
||||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
|
||||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
|
||||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
|
||||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
|
||||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
|
||||||
};
|
};
|
||||||
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
for (int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||||
INDArray outputParam = dqn.outputParams.get(i);
|
INDArray outputParam = dqn.outputParams.get(i);
|
||||||
|
|
||||||
assertEquals(5, outputParam.shape()[1]);
|
assertEquals(5, outputParam.shape()[1]);
|
||||||
assertEquals(1, outputParam.shape()[2]);
|
assertEquals(1, outputParam.shape()[2]);
|
||||||
|
|
||||||
double[] expectedRow = expectedDQNOutput[i];
|
double[] expectedRow = expectedDQNOutput[i];
|
||||||
for(int j = 0; j < expectedRow.length; ++j) {
|
for (int j = 0; j < expectedRow.length; ++j) {
|
||||||
assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
|
assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MDP calls
|
// MDP calls
|
||||||
assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
|
||||||
|
|
||||||
// ExpReplay calls
|
// ExpReplay calls
|
||||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
|
double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0};
|
||||||
int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
|
int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4};
|
||||||
double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0};
|
||||||
double[][] expectedTrObservations = new double[][] {
|
double[][] expectedTrObservations = new double[][]{
|
||||||
new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
|
new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
|
||||||
new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
|
new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
|
||||||
new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
|
new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
|
||||||
new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
|
new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
|
||||||
new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
|
new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
|
||||||
new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
|
new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
|
||||||
new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
|
new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
|
||||||
new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
|
new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
|
||||||
};
|
};
|
||||||
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
|
||||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
for (int i = 0; i < expectedTrRewards.length; ++i) {
|
||||||
Transition tr = expReplay.transitions.get(i);
|
Transition tr = expReplay.transitions.get(i);
|
||||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||||
assertEquals(expectedTrActions[i], tr.getAction());
|
assertEquals(expectedTrActions[i], tr.getAction());
|
||||||
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
|
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
|
||||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
for (int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||||
assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,12 +162,12 @@ public class QLearningDiscreteTest {
|
||||||
assertEquals(initStepCount + 16, result.getStepCounter());
|
assertEquals(initStepCount + 16, result.getStepCounter());
|
||||||
assertEquals(300.0, result.getReward(), 0.00001);
|
assertEquals(300.0, result.getReward(), 0.00001);
|
||||||
assertTrue(dqn.hasBeenReset);
|
assertTrue(dqn.hasBeenReset);
|
||||||
assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
|
assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp, IDQN dqn,
|
||||||
QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
|
||||||
int epsilonNbStep, Random rnd) {
|
int epsilonNbStep, Random rnd) {
|
||||||
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
super(mdp, dqn, conf, epsilonNbStep, rnd);
|
||||||
addListener(new DataManagerTrainingListener(dataManager));
|
addListener(new DataManagerTrainingListener(dataManager));
|
||||||
|
@ -146,10 +176,10 @@ public class QLearningDiscreteTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
protected DataSet setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||||
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setExpReplay(IExpReplay<Integer> exp){
|
public void setExpReplay(IExpReplay<Integer> exp) {
|
||||||
this.expReplay = exp;
|
this.expReplay = exp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,6 +17,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.network.ac;
|
package org.deeplearning4j.rl4j.network.ac;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @author saudet
|
* @author saudet
|
||||||
*/
|
*/
|
||||||
public class ActorCriticTest {
|
public class ActorCriticTest {
|
||||||
|
|
||||||
public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF =
|
public static ActorCriticDenseNetworkConfiguration NET_CONF =
|
||||||
new ActorCriticFactorySeparateStdDense.Configuration(
|
ActorCriticDenseNetworkConfiguration.builder()
|
||||||
4, //number of layers
|
.numLayers(4)
|
||||||
32, //number of hidden nodes
|
.numHiddenNodes(32)
|
||||||
0.001, //l2 regularization
|
.l2(0.001)
|
||||||
new RmsProp(0.0005), null, false
|
.updater(new RmsProp(0.0005))
|
||||||
);
|
.useLSTM(false)
|
||||||
|
.build();
|
||||||
|
|
||||||
public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG =
|
public static ActorCriticDenseNetworkConfiguration NET_CONF_CG =
|
||||||
new ActorCriticFactoryCompGraphStdDense.Configuration(
|
ActorCriticDenseNetworkConfiguration.builder()
|
||||||
2, //number of layers
|
.numLayers(2)
|
||||||
128, //number of hidden nodes
|
.numHiddenNodes(128)
|
||||||
0.00001, //l2 regularization
|
.l2(0.00001)
|
||||||
new RmsProp(0.005), null, true
|
.updater(new RmsProp(0.005))
|
||||||
);
|
.useLSTM(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testModelLoadSave() throws IOException {
|
public void testModelLoadSave() throws IOException {
|
||||||
ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5);
|
ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5);
|
||||||
|
|
||||||
File fileValue = File.createTempFile("rl4j-value-", ".model");
|
File fileValue = File.createTempFile("rl4j-value-", ".model");
|
||||||
File filePolicy = File.createTempFile("rl4j-policy-", ".model");
|
File filePolicy = File.createTempFile("rl4j-policy-", ".model");
|
||||||
|
@ -63,7 +66,7 @@ public class ActorCriticTest {
|
||||||
assertEquals(acs.valueNet, acs2.valueNet);
|
assertEquals(acs.valueNet, acs2.valueNet);
|
||||||
assertEquals(acs.policyNet, acs2.policyNet);
|
assertEquals(acs.policyNet, acs2.policyNet);
|
||||||
|
|
||||||
ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43);
|
ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43);
|
||||||
|
|
||||||
File file = File.createTempFile("rl4j-cg-", ".model");
|
File file = File.createTempFile("rl4j-cg-", ".model");
|
||||||
accg.save(file.getAbsolutePath());
|
accg.save(file.getAbsolutePath());
|
||||||
|
@ -83,15 +86,15 @@ public class ActorCriticTest {
|
||||||
|
|
||||||
for (double i = eps; i < n; i++) {
|
for (double i = eps; i < n; i++) {
|
||||||
for (double j = eps; j < n; j++) {
|
for (double j = eps; j < n; j++) {
|
||||||
INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2});
|
INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2});
|
||||||
INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
|
INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
|
||||||
INDArray gradient = loss.computeGradient(labels, output, activation, null);
|
INDArray gradient = loss.computeGradient(labels, output, activation, null);
|
||||||
|
|
||||||
output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
|
output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
|
||||||
double score = loss.computeScore(labels, output, activation, null, false);
|
double score = loss.computeScore(labels, output, activation, null, false);
|
||||||
INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2});
|
INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2});
|
||||||
double score1 = loss.computeScore(labels, output1, activation, null, false);
|
double score1 = loss.computeScore(labels, output1, activation, null, false);
|
||||||
INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2});
|
INDArray output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2});
|
||||||
double score2 = loss.computeScore(labels, output2, activation, null, false);
|
double score2 = loss.computeScore(labels, output2, activation, null, false);
|
||||||
|
|
||||||
double gradient1 = (score1 - score) / eps;
|
double gradient1 = (score1 - score) / eps;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,6 +17,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.network.dqn;
|
package org.deeplearning4j.rl4j.network.dqn;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.learning.config.RmsProp;
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
|
|
||||||
|
@ -25,22 +27,20 @@ import java.io.IOException;
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @author saudet
|
* @author saudet
|
||||||
*/
|
*/
|
||||||
public class DQNTest {
|
public class DQNTest {
|
||||||
|
|
||||||
public static DQNFactoryStdDense.Configuration NET_CONF =
|
private static DQNDenseNetworkConfiguration NET_CONF =
|
||||||
new DQNFactoryStdDense.Configuration(
|
DQNDenseNetworkConfiguration.builder().numLayers(3)
|
||||||
3, //number of layers
|
.numHiddenNodes(16)
|
||||||
16, //number of hidden nodes
|
.l2(0.001)
|
||||||
0.001, //l2 regularization
|
.updater(new RmsProp(0.0005))
|
||||||
new RmsProp(0.0005), null
|
.build();
|
||||||
);
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testModelLoadSave() throws IOException {
|
public void testModelLoadSave() throws IOException {
|
||||||
DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13);
|
DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13);
|
||||||
|
|
||||||
File file = File.createTempFile("rl4j-dqn-", ".model");
|
File file = File.createTempFile("rl4j-dqn-", ".model");
|
||||||
dqn.save(file.getAbsolutePath());
|
dqn.save(file.getAbsolutePath());
|
||||||
|
|
|
@ -128,7 +128,7 @@ public class TransformProcessTest {
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
assertFalse(result.isSkipped());
|
assertFalse(result.isSkipped());
|
||||||
assertEquals(1, result.getData().shape().length);
|
assertEquals(2, result.getData().shape().length);
|
||||||
assertEquals(1, result.getData().shape()[0]);
|
assertEquals(1, result.getData().shape()[0]);
|
||||||
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
|
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.Learning;
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
|
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
import org.deeplearning4j.rl4j.support.MockDQN;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
import org.deeplearning4j.rl4j.support.MockEncodable;
|
||||||
import org.deeplearning4j.rl4j.support.*;
|
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockNeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||||
|
import org.deeplearning4j.rl4j.support.MockRandom;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
@ -186,8 +187,22 @@ public class PolicyTest {
|
||||||
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
|
||||||
MockMDP mdp = new MockMDP(observationSpace, 30, random);
|
MockMDP mdp = new MockMDP(observationSpace, 30, random);
|
||||||
|
|
||||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
QLearningConfiguration conf = QLearningConfiguration.builder()
|
||||||
0, 1.0, 0, 0, 0, 0, true);
|
.seed(0L)
|
||||||
|
.maxEpochStep(0)
|
||||||
|
.maxStep(0)
|
||||||
|
.expRepMaxSize(5)
|
||||||
|
.batchSize(1)
|
||||||
|
.targetDqnUpdateFreq(0)
|
||||||
|
.updateStart(0)
|
||||||
|
.rewardFactor(1.0)
|
||||||
|
.gamma(0)
|
||||||
|
.errorClamp(0)
|
||||||
|
.minEpsilon(0)
|
||||||
|
.epsilonNbStep(0)
|
||||||
|
.doubleDQN(true)
|
||||||
|
.build();
|
||||||
|
|
||||||
MockNeuralNet nnMock = new MockNeuralNet();
|
MockNeuralNet nnMock = new MockNeuralNet();
|
||||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||||
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
|
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
|
||||||
|
|
|
@ -1,22 +1,37 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.support;
|
package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Value
|
@Value
|
||||||
public class MockAsyncConfiguration implements AsyncConfiguration {
|
@AllArgsConstructor
|
||||||
|
public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
|
||||||
|
|
||||||
private Integer seed;
|
private Long seed;
|
||||||
private int maxEpochStep;
|
private int maxEpochStep;
|
||||||
private int maxStep;
|
private int maxStep;
|
||||||
private int numThread;
|
|
||||||
private int nstep;
|
|
||||||
private int targetDqnUpdateFreq;
|
|
||||||
private int updateStart;
|
private int updateStart;
|
||||||
private double rewardFactor;
|
private double rewardFactor;
|
||||||
private double gamma;
|
private double gamma;
|
||||||
private double errorClamp;
|
private double errorClamp;
|
||||||
|
private int numThreads;
|
||||||
|
private int nStep;
|
||||||
|
private int learnerUpdateFrequency;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,20 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.util;
|
package org.deeplearning4j.rl4j.util;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -5,6 +22,7 @@ import lombok.Setter;
|
||||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
import org.deeplearning4j.rl4j.learning.ILearning;
|
import org.deeplearning4j.rl4j.learning.ILearning;
|
||||||
|
import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
|
@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LConfiguration getConfiguration() {
|
public ILearningConfiguration getConfiguration() {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -16,19 +17,18 @@
|
||||||
|
|
||||||
package org.deeplearning4j.malmo;
|
package org.deeplearning4j.malmo;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import com.microsoft.msr.malmo.TimestampedStringVector;
|
||||||
|
import com.microsoft.msr.malmo.WorldState;
|
||||||
import org.json.JSONArray;
|
import org.json.JSONArray;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import com.microsoft.msr.malmo.TimestampedStringVector;
|
import java.util.HashMap;
|
||||||
import com.microsoft.msr.malmo.WorldState;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Observation space that contains a grid of Minecraft blocks
|
* Observation space that contains a grid of Minecraft blocks
|
||||||
|
*
|
||||||
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
|
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
|
||||||
*/
|
*/
|
||||||
public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
|
public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
|
||||||
|
@ -44,10 +44,10 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
|
||||||
/**
|
/**
|
||||||
* Construct observation space from a array of blocks policy should distinguish between.
|
* Construct observation space from a array of blocks policy should distinguish between.
|
||||||
*
|
*
|
||||||
* @param name Name given to Grid element in mission specification
|
* @param name Name given to Grid element in mission specification
|
||||||
* @param xSize total x size of grid
|
* @param xSize total x size of grid
|
||||||
* @param ySize total y size of grid
|
* @param ySize total y size of grid
|
||||||
* @param zSize total z size of grid
|
* @param zSize total z size of grid
|
||||||
* @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings!
|
* @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings!
|
||||||
*/
|
*/
|
||||||
public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... blocks) {
|
public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... blocks) {
|
||||||
|
@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int[] getShape() {
|
public int[] getShape() {
|
||||||
return new int[] {totalSize};
|
return new int[]{totalSize};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue