getMdp();
IHistoryProcessor getHistoryProcessor();
- interface LConfiguration {
- Integer getSeed();
-
- int getMaxEpochStep();
-
- int getMaxStep();
-
- double getGamma();
- }
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
index 5501a29e1..01c519b57 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,6 +20,8 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.primitives.Pair;
@@ -27,28 +30,26 @@ import java.util.concurrent.atomic.AtomicInteger;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
- *
+ *
 * In the original paper, the authors use Asynchronous
* Gradient Descent: Hogwild! It is a way to apply gradients
* and modify a model in a lock-free manner.
- *
+ *
* As a way to implement this with dl4j, it is unfortunately
* necessary at the time of writing to apply the gradient
* (update the parameters) on a single separate global thread.
- *
+ *
 * This central thread for asynchronous reinforcement learning methods
 * enqueues the gradients coming from the different threads and updates its
 * model and target. Those neural nets are then synced by the other threads.
- *
+ *
 * The benefit of this thread is that the updater is "shared" between all threads:
 * there is a single updater, the one of the model contained here
- *
+ *
* This is similar to RMSProp with shared g and momentum
- *
+ *
* When Hogwild! is implemented, this could be replaced by a simple data
* structure
- *
- *
*/
@Slf4j
 public class AsyncGlobal<NN extends NeuralNet> extends Thread implements IAsyncGlobal<NN> {
@@ -56,7 +57,7 @@ public class AsyncGlobal extends Thread implements IAsyncG
@Getter
final private NN current;
 final private ConcurrentLinkedQueue<Pair<Gradient[], Integer>> queue;
- final private AsyncConfiguration a3cc;
+ final private IAsyncLearningConfiguration configuration;
private final IAsyncLearning learning;
@Getter
private AtomicInteger T = new AtomicInteger(0);
@@ -65,20 +66,20 @@ public class AsyncGlobal extends Thread implements IAsyncG
@Getter
private boolean running = true;
- public AsyncGlobal(NN initial, AsyncConfiguration a3cc, IAsyncLearning learning) {
+ public AsyncGlobal(NN initial, IAsyncLearningConfiguration configuration, IAsyncLearning learning) {
this.current = initial;
target = (NN) initial.clone();
- this.a3cc = a3cc;
+ this.configuration = configuration;
this.learning = learning;
queue = new ConcurrentLinkedQueue<>();
}
public boolean isTrainingComplete() {
- return T.get() >= a3cc.getMaxStep();
+ return T.get() >= configuration.getMaxStep();
}
public void enqueue(Gradient[] gradient, Integer nstep) {
- if(running && !isTrainingComplete()) {
+ if (running && !isTrainingComplete()) {
queue.add(new Pair<>(gradient, nstep));
}
}
@@ -94,9 +95,8 @@ public class AsyncGlobal extends Thread implements IAsyncG
synchronized (this) {
current.applyGradient(gradient, pair.getSecond());
}
- if (a3cc.getTargetDqnUpdateFreq() != -1
- && T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond())
- / a3cc.getTargetDqnUpdateFreq()) {
+ if (configuration.getLearnerUpdateFrequency() != -1 && T.get() / configuration.getLearnerUpdateFrequency() > (T.get() - pair.getSecond())
+ / configuration.getLearnerUpdateFrequency()) {
log.info("TARGET UPDATE at T = " + T.get());
synchronized (this) {
target.copy(current);
@@ -111,7 +111,7 @@ public class AsyncGlobal extends Thread implements IAsyncG
* Force the immediate termination of the AsyncGlobal instance. Queued work items will be discarded and the AsyncLearning instance will be forced to terminate too.
*/
public void terminate() {
- if(running) {
+ if (running) {
running = false;
queue.clear();
learning.terminate();
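
The renamed getLearnerUpdateFrequency() check above fires whenever the global step counter T has crossed a multiple of the configured frequency during the just-applied n-step gradient batch. A standalone sketch of that arithmetic, with hypothetical names and values (not part of the patch):

    // Hypothetical illustration of the target-update condition in AsyncGlobal.run().
    public class TargetUpdateCheckSketch {
        static boolean shouldUpdateTarget(int globalStep, int nStep, int learnerUpdateFrequency) {
            // -1 disables the periodic target refresh on this code path
            if (learnerUpdateFrequency == -1) return false;
            // true exactly when globalStep crossed a multiple of the frequency
            // somewhere within the last nStep steps
            return globalStep / learnerUpdateFrequency > (globalStep - nStep) / learnerUpdateFrequency;
        }

        public static void main(String[] args) {
            System.out.println(shouldUpdateTarget(100, 5, 100)); // true:  100/100 = 1 > 95/100 = 0
            System.out.println(shouldUpdateTarget(99, 5, 100));  // false: 99/100 = 0 is not > 94/100 = 0
        }
    }
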
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
index 994ec9cb0..1c3c83972 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -21,14 +22,17 @@ import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.listener.*;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
+import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.factory.Nd4j;
/**
- * The entry point for async training. This class will start a number ({@link AsyncConfiguration#getNumThread()
+ * The entry point for async training. This class will start a number ({@link AsyncQLearningConfiguration#getNumThreads()
* configuration.getNumThread()}) of worker threads. Then, it will monitor their progress at regular intervals
* (see setProgressEventInterval(int))
*
@@ -37,8 +41,8 @@ import org.nd4j.linalg.factory.Nd4j;
*/
@Slf4j
 public abstract class AsyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
-                extends Learning<O, A, AS, NN>
-                implements IAsyncLearning {
+        extends Learning<O, A, AS, NN>
+        implements IAsyncLearning {
private Thread monitorThread = null;
@@ -56,9 +60,10 @@ public abstract class AsyncLearning, NN extends Ne
}
private void handleTraining(RunContext context) {
- int maxSteps = Math.min(getConf().getNstep(), getConf().getMaxEpochStep() - currentEpochStep);
+ int maxSteps = Math.min(getConf().getNStep(), getConf().getMaxEpochStep() - currentEpochStep);
SubEpochReturn subEpochReturn = trainSubEpoch(context.obs, maxSteps);
context.obs = subEpochReturn.getLastObs();
@@ -197,7 +199,7 @@ public abstract class AsyncThread, NN extends Ne
protected abstract IAsyncGlobal getAsyncGlobal();
- protected abstract AsyncConfiguration getConf();
+ protected abstract IAsyncLearningConfiguration getConf();
protected abstract IPolicy getPolicy(NN net);
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
index 27d49c366..a72abfa62 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java
@@ -1,5 +1,7 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -112,7 +114,7 @@ public abstract class AsyncThreadDiscrete
rewards.add(new MiniTrans(obs.getData(), null, null, 0));
else {
INDArray[] output = null;
- if (getConf().getTargetDqnUpdateFreq() == -1)
+ if (getConf().getLearnerUpdateFrequency() == -1)
output = current.outputAll(obs.getData());
else synchronized (getAsyncGlobal()) {
output = getAsyncGlobal().getTarget().outputAll(obs.getData());
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
index 81308ba5a..0608ec5cc 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,11 +17,15 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
-import lombok.*;
-import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
+import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.ACPolicy;
@@ -32,15 +37,14 @@ import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16.
* Training for A3C in the Discrete Domain
- *
+ *
* All methods are fully implemented as described in the
* https://arxiv.org/abs/1602.01783 paper.
- *
*/
 public abstract class A3CDiscrete<O extends Encodable> extends AsyncLearning<O, Integer, DiscreteSpace, IActorCritic> {
@Getter
- final public A3CConfiguration configuration;
+ final public A3CLearningConfiguration configuration;
@Getter
final protected MDP mdp;
final private IActorCritic iActorCritic;
@@ -49,15 +53,15 @@ public abstract class A3CDiscrete extends AsyncLearning policy;
- public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CConfiguration conf) {
+ public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) {
this.iActorCritic = iActorCritic;
this.mdp = mdp;
this.configuration = conf;
asyncGlobal = new AsyncGlobal<>(iActorCritic, conf, this);
- Integer seed = conf.getSeed();
+ Long seed = conf.getSeed();
Random rnd = Nd4j.getRandom();
- if(seed != null) {
+ if (seed != null) {
rnd.setSeed(seed);
}
@@ -65,7 +69,7 @@ public abstract class A3CDiscrete extends AsyncLearning extends AsyncLearning extends AsyncLearning
* Training for A3C in the Discrete Domain
- *
+ *
* Specialized constructors for the Conv (pixels input) case
* Specialized conf + provide additional type safety
- *
+ *
* It uses CompGraph because there is benefit to combine the
* first layers since they're essentially doing the same dimension
* reduction task
- *
**/
 public class A3CDiscreteConv<O extends Encodable> extends A3CDiscrete<O> {
@@ -46,12 +48,22 @@ public class A3CDiscreteConv extends A3CDiscrete {
@Deprecated
public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic,
- HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
+ HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, actorCritic, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
+
+ @Deprecated
public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
+
+ super(mdp, IActorCritic, conf.toLearningConfiguration());
+ this.hpconf = hpconf;
+ setHistoryProcessor(hpconf);
+ }
+
+ public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic,
+ HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
super(mdp, IActorCritic, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
@@ -59,21 +71,35 @@ public class A3CDiscreteConv extends A3CDiscrete {
@Deprecated
public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory,
- HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
+ HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
+
+ @Deprecated
public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
+ public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory,
+ HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
+ this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
+ }
+
@Deprecated
public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
- HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
- this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf, dataManager);
+ HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) {
+ this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
+
+ @Deprecated
public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, A3CConfiguration conf) {
+ this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
+ }
+
+ public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf,
+ HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) {
this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf);
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java
index 16b8151df..74332bf3a 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
+import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.*;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
@@ -25,67 +28,81 @@ import org.deeplearning4j.rl4j.util.IDataManager;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/8/16.
- *
+ *
* Training for A3C in the Discrete Domain
- *
+ *
* We use specifically the Separate version because
* the model is too small to have enough benefit by sharing layers
- *
*/
 public class A3CDiscreteDense<O extends Encodable> extends A3CDiscrete<O> {
@Deprecated
public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf,
- IDataManager dataManager) {
+ IDataManager dataManager) {
this(mdp, IActorCritic, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
+
+ @Deprecated
public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) {
+ super(mdp, actorCritic, conf.toLearningConfiguration());
+ }
+
+ public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) {
super(mdp, actorCritic, conf);
}
@Deprecated
public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory,
- A3CConfiguration conf, IDataManager dataManager) {
+ A3CConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
- dataManager);
+ dataManager);
}
+
+ @Deprecated
public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory,
A3CConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
- @Deprecated
- public A3CDiscreteDense(MDP mdp,
- ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
- IDataManager dataManager) {
- this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf, dataManager);
- }
- public A3CDiscreteDense(MDP mdp,
- ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
- this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
- }
-
- @Deprecated
- public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
- A3CConfiguration conf, IDataManager dataManager) {
- this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
- dataManager);
- }
- public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
- A3CConfiguration conf) {
+ public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory,
+ A3CLearningConfiguration conf) {
this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
@Deprecated
public A3CDiscreteDense(MDP mdp,
- ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf,
- IDataManager dataManager) {
- this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf, dataManager);
- }
- public A3CDiscreteDense(MDP mdp,
- ActorCriticFactoryCompGraphStdDense.Configuration netConf, A3CConfiguration conf) {
- this(mdp, new ActorCriticFactoryCompGraphStdDense(netConf), conf);
+ ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf,
+ IDataManager dataManager) {
+ this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
+ @Deprecated
+ public A3CDiscreteDense(MDP mdp,
+ ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) {
+ this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf);
+ }
+
+ public A3CDiscreteDense(MDP mdp,
+ ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) {
+ this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf);
+ }
+
+ @Deprecated
+ public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
+ A3CConfiguration conf, IDataManager dataManager) {
+ this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
+ dataManager);
+ }
+
+ @Deprecated
+ public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
+ A3CConfiguration conf) {
+ this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
+ }
+
+ public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory,
+ A3CLearningConfiguration conf) {
+ this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
+ }
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
index 22b3894b2..c2a16d6b4 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,10 +20,10 @@ package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import lombok.Getter;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
+import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
@@ -31,9 +32,9 @@ import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
-import org.nd4j.linalg.api.rng.Random;
import java.util.Stack;
@@ -45,7 +46,7 @@ import java.util.Stack;
 public class A3CThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IActorCritic> {
@Getter
- final protected A3CDiscrete.A3CConfiguration conf;
+ final protected A3CLearningConfiguration conf;
@Getter
final protected IAsyncGlobal asyncGlobal;
@Getter
@@ -54,14 +55,14 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete<
final private Random rnd;
public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal,
- A3CDiscrete.A3CConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
+ A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners,
int threadNumber) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = a3cc;
this.asyncGlobal = asyncGlobal;
this.threadNumber = threadNumber;
- Integer seed = conf.getSeed();
+ Long seed = conf.getSeed();
rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed + threadNumber);
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
index c18de9e10..9a8049f6f 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java
@@ -1,49 +1,53 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
*
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
-import lombok.*;
-import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncLearning;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.factory.Nd4j;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
*/
 public abstract class AsyncNStepQLearningDiscrete<O extends Encodable>
-                extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
+        extends AsyncLearning<O, Integer, DiscreteSpace, IDQN> {
@Getter
- final public AsyncNStepQLConfiguration configuration;
+ final public AsyncQLearningConfiguration configuration;
@Getter
final private MDP mdp;
@Getter
final private AsyncGlobal asyncGlobal;
- public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) {
+ public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) {
this.mdp = mdp;
this.configuration = conf;
this.asyncGlobal = new AsyncGlobal<>(dqn, conf, this);
@@ -62,12 +66,11 @@ public abstract class AsyncNStepQLearningDiscrete
return new DQNPolicy(getNeuralNet());
}
-
@Data
@AllArgsConstructor
@Builder
@EqualsAndHashCode(callSuper = false)
- public static class AsyncNStepQLConfiguration implements AsyncConfiguration {
+ public static class AsyncNStepQLConfiguration {
Integer seed;
int maxEpochStep;
@@ -82,5 +85,22 @@ public abstract class AsyncNStepQLearningDiscrete
float minEpsilon;
int epsilonNbStep;
+ public AsyncQLearningConfiguration toLearningConfiguration() {
+ return AsyncQLearningConfiguration.builder()
+ .seed(new Long(seed))
+ .maxEpochStep(maxEpochStep)
+ .maxStep(maxStep)
+ .numThreads(numThread)
+ .nStep(nstep)
+ .targetDqnUpdateFreq(targetDqnUpdateFreq)
+ .updateStart(updateStart)
+ .rewardFactor(rewardFactor)
+ .gamma(gamma)
+ .errorClamp(errorClamp)
+ .minEpsilon(minEpsilon)
+ .build();
+ }
+
}
+
}
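
A hedged usage sketch of the toLearningConfiguration() bridge above; the values are placeholders and the wrapper class name is hypothetical:

    // Illustrative only: bridging from the deprecated nested configuration to the
    // new AsyncQLearningConfiguration introduced in this patch.
    import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration;
    import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;

    public class AsyncConfigMigrationSketch {
        public static void main(String[] args) {
            AsyncNStepQLConfiguration legacy = AsyncNStepQLConfiguration.builder()
                    .seed(123)                 // Integer on the legacy class
                    .maxEpochStep(200)
                    .maxStep(500_000)
                    .numThread(8)
                    .nstep(5)
                    .targetDqnUpdateFreq(100)
                    .updateStart(10)
                    .rewardFactor(0.1)
                    .gamma(0.99)
                    .errorClamp(1.0)
                    .minEpsilon(0.1f)
                    .epsilonNbStep(10_000)
                    .build();

            // The bridge maps the legacy names onto the new builder
            // (seed becomes a Long, numThread becomes numThreads, nstep becomes nStep).
            AsyncQLearningConfiguration conf = legacy.toLearningConfiguration();
            System.out.println(conf);
        }
    }
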
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java
index 83274b7f6..f92b704b6 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java
@@ -1,24 +1,27 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
*
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -38,12 +41,12 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
this.hpconf = hpconf;
setHistoryProcessor(hpconf);
@@ -51,21 +54,21 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN
@Deprecated
public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
@Deprecated
- public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf,
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
}
- public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
- HistoryProcessor.Configuration hpconf, AsyncNStepQLConfiguration conf) {
+ public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf,
+ HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java
index b58e15902..b6216e849 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java
@@ -1,22 +1,26 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
*
- * SPDX-License-Identifier: Apache-2.0
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -32,35 +36,56 @@ public class AsyncNStepQLearningDiscreteDense extends Async
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn,
- AsyncNStepQLConfiguration conf, IDataManager dataManager) {
- super(mdp, dqn, conf);
+ AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ super(mdp, dqn, conf.toLearningConfiguration());
addListener(new DataManagerTrainingListener(dataManager));
}
+ @Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn,
AsyncNStepQLConfiguration conf) {
+ super(mdp, dqn, conf.toLearningConfiguration());
+ }
+
+ public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn,
+ AsyncQLearningConfiguration conf) {
super(mdp, dqn, conf);
}
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory,
- AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ AsyncNStepQLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
- dataManager);
+ dataManager);
}
+
+ @Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory,
AsyncNStepQLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
+ public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory,
+ AsyncQLearningConfiguration conf) {
+ this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
+ }
+
@Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp,
- DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
- this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
+ DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) {
+ this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
+
+ @Deprecated
public AsyncNStepQLearningDiscreteDense(MDP mdp,
DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) {
+ this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
+ }
+
+ public AsyncNStepQLearningDiscreteDense(MDP mdp,
+ DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
+
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
index f8c470269..71199efaf 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java
@@ -1,17 +1,18 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
*
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
*
- * SPDX-License-Identifier: Apache-2.0
+ * SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
@@ -22,6 +23,7 @@ import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete;
import org.deeplearning4j.rl4j.learning.async.IAsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -42,7 +44,7 @@ import java.util.Stack;
 public class AsyncNStepQLearningThreadDiscrete<O extends Encodable> extends AsyncThreadDiscrete<O, IDQN> {
@Getter
- final protected AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf;
+ final protected AsyncQLearningConfiguration conf;
@Getter
final protected IAsyncGlobal asyncGlobal;
@Getter
@@ -51,7 +53,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn
final private Random rnd;
public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal,
- AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf,
+ AsyncQLearningConfiguration conf,
TrainingListenerList listeners, int threadNumber, int deviceNum) {
super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
this.conf = conf;
@@ -59,7 +61,7 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn
this.threadNumber = threadNumber;
rnd = Nd4j.getRandom();
- Integer seed = conf.getSeed();
+ Long seed = conf.getSeed();
if(seed != null) {
rnd.setSeed(seed + threadNumber);
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java
new file mode 100644
index 000000000..226fe4419
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@EqualsAndHashCode(callSuper = true)
+public class A3CLearningConfiguration extends LearningConfiguration implements IAsyncLearningConfiguration {
+
+ /**
+ * The number of asynchronous threads to use to generate gradients
+ */
+ private final int numThreads;
+
+ /**
+ * The number of steps to calculate gradients over
+ */
+ private final int nStep;
+
+ /**
+ * The frequency of async training iterations to update the target network.
+ *
+ * If this is set to -1 then the target network is updated after every training iteration
+ */
+ @Builder.Default
+ private int learnerUpdateFrequency = -1;
+}
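
A minimal builder sketch for this class, assuming the LearningConfiguration base class added later in this patch supplies the inherited seed, maxEpochStep, maxStep, gamma and rewardFactor fields; the values are placeholders:

    // Illustrative only: constructing the new A3C configuration via the Lombok @SuperBuilder chain.
    import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;

    public class A3CConfigSketch {
        public static void main(String[] args) {
            A3CLearningConfiguration conf = A3CLearningConfiguration.builder()
                    .seed(123L)                  // Long, inherited from LearningConfiguration
                    .maxEpochStep(500)
                    .maxStep(500_000)
                    .gamma(0.99)
                    .rewardFactor(1.0)
                    .numThreads(8)               // worker threads generating gradients
                    .nStep(20)                   // gradient accumulation horizon
                    .learnerUpdateFrequency(-1)  // default: target refreshed after every training iteration
                    .build();
            System.out.println(conf);
        }
    }
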
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java
new file mode 100644
index 000000000..a60903e59
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java
@@ -0,0 +1,42 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.SuperBuilder;
+
+
+@Data
+@SuperBuilder
+@EqualsAndHashCode(callSuper = true)
+public class AsyncQLearningConfiguration extends QLearningConfiguration implements IAsyncLearningConfiguration {
+
+ /**
+ * The number of asynchronous threads to use to generate experience data
+ */
+ private final int numThreads;
+
+ /**
+ * The number of steps in each training iteration
+ */
+ private final int nStep;
+
+ public int getLearnerUpdateFrequency() {
+ return getTargetDqnUpdateFreq();
+ }
+}
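
A short sketch, with illustrative values only, of how the async code paths consume this class through the IAsyncLearningConfiguration view, where getLearnerUpdateFrequency() simply forwards the inherited targetDqnUpdateFreq:

    // Illustrative only: targetDqnUpdateFreq is surfaced to async code as the learner update frequency.
    import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
    import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;

    public class AsyncQConfigSketch {
        public static void main(String[] args) {
            IAsyncLearningConfiguration conf = AsyncQLearningConfiguration.builder()
                    .numThreads(8)
                    .nStep(5)
                    .targetDqnUpdateFreq(100)
                    .build();
            System.out.println(conf.getLearnerUpdateFrequency()); // prints 100
        }
    }
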
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java
new file mode 100644
index 000000000..1e7cf3f2e
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java
@@ -0,0 +1,28 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+public interface IAsyncLearningConfiguration extends ILearningConfiguration {
+
+ int getNumThreads();
+
+ int getNStep();
+
+ int getLearnerUpdateFrequency();
+
+ int getMaxStep();
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java
similarity index 61%
rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java
rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java
index 0727db475..7ae215087 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncConfiguration.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -14,36 +14,16 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.deeplearning4j.rl4j.learning.async;
+package org.deeplearning4j.rl4j.learning.configuration;
-import org.deeplearning4j.rl4j.learning.ILearning;
-
-/**
- * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/23/16.
- *
- * Interface configuration for all training method that inherit
- * from AsyncLearning
- */
-public interface AsyncConfiguration extends ILearning.LConfiguration {
-
- Integer getSeed();
+public interface ILearningConfiguration {
+ Long getSeed();
int getMaxEpochStep();
int getMaxStep();
- int getNumThread();
-
- int getNstep();
-
- int getTargetDqnUpdateFreq();
-
- int getUpdateStart();
-
- double getRewardFactor();
-
double getGamma();
- double getErrorClamp();
-
+ double getRewardFactor();
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java
new file mode 100644
index 000000000..d1567e619
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java
@@ -0,0 +1,59 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+public class LearningConfiguration implements ILearningConfiguration {
+
+ /**
+ * Seed value used for training
+ */
+ @Builder.Default
+ private Long seed = System.currentTimeMillis();
+
+ /**
+ * The maximum number of steps in each episode
+ */
+ @Builder.Default
+ private int maxEpochStep = 200;
+
+ /**
+ * The maximum number of steps to train for
+ */
+ @Builder.Default
+ private int maxStep = 150000;
+
+ /**
+ * Gamma parameter used for discounted rewards
+ */
+ @Builder.Default
+ private double gamma = 0.99;
+
+ /**
+ * Scaling parameter for rewards
+ */
+ @Builder.Default
+ private double rewardFactor = 1.0;
+
+}
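
A small sketch of what the @Builder.Default values above mean in practice: building with no fields set yields the documented defaults (illustrative only):

    // Illustrative only: unset fields fall back to their @Builder.Default values.
    import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;

    public class LearningConfigDefaultsSketch {
        public static void main(String[] args) {
            LearningConfiguration conf = LearningConfiguration.builder().build();
            System.out.println(conf.getMaxEpochStep()); // 200
            System.out.println(conf.getMaxStep());      // 150000
            System.out.println(conf.getGamma());        // 0.99
            System.out.println(conf.getRewardFactor()); // 1.0
            // getSeed() defaults to System.currentTimeMillis() evaluated at build time
        }
    }
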
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java
new file mode 100644
index 000000000..26ac57f0c
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java
@@ -0,0 +1,79 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.learning.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+@EqualsAndHashCode(callSuper = false)
+public class QLearningConfiguration extends LearningConfiguration {
+
+ /**
+ * The maximum size of the experience replay buffer
+ */
+ @Builder.Default
+ private int expRepMaxSize = 150000;
+
+ /**
+ * The batch size of experience for each training iteration
+ */
+ @Builder.Default
+ private int batchSize = 32;
+
+ /**
+ * How many steps between target network updates
+ */
+ @Builder.Default
+ private int targetDqnUpdateFreq = 100;
+
+ /**
+ * The number of steps to wait before sampling batches from the experience replay buffer
+ */
+ @Builder.Default
+ private int updateStart = 10;
+
+ /**
+ * Prevents the new Q-value from moving farther than errorClamp away from the previous value. Double.NaN disables clamping
+ */
+ @Builder.Default
+ private double errorClamp = 1.0;
+
+ /**
+ * The minimum probability of a random exploration action during epsilon-greedy annealing
+ */
+ @Builder.Default
+ private double minEpsilon = 0.1f;
+
+ /**
+ * The number of steps to anneal epsilon to its minimum value.
+ */
+ @Builder.Default
+ private int epsilonNbStep = 10000;
+
+ /**
+ * Whether to use the double DQN algorithm
+ */
+ @Builder.Default
+ private boolean doubleDQN = false;
+
+}
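
A hedged builder sketch for the sync Q-learning case; unset fields fall back to the @Builder.Default values above, and the concrete numbers are placeholders rather than recommendations:

    // Illustrative only: a QLearningConfiguration for the sync DQN learners updated later in this patch.
    import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;

    public class QLearningConfigSketch {
        public static void main(String[] args) {
            QLearningConfiguration conf = QLearningConfiguration.builder()
                    .seed(123L)
                    .maxEpochStep(200)
                    .maxStep(150_000)
                    .expRepMaxSize(150_000)
                    .batchSize(32)
                    .targetDqnUpdateFreq(500)
                    .updateStart(10)
                    .rewardFactor(0.01)
                    .gamma(0.99)
                    .errorClamp(1.0)
                    .minEpsilon(0.1)
                    .epsilonNbStep(1_000)
                    .doubleDQN(true)
                    .build();
            System.out.println(conf);
        }
    }
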
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
index 22d936fcf..c42756145 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -63,7 +64,7 @@ public abstract class SyncLearning, NN extends N
/**
* This method will train the model
* The training stop when:
- * - the number of steps reaches the maximum defined in the configuration (see {@link LConfiguration#getMaxStep() LConfiguration.getMaxStep()})
+ * - the number of steps reaches the maximum defined in the configuration (see {@link ILearningConfiguration#getMaxStep() ILearningConfiguration.getMaxStep()})
* OR
* - a listener explicitly stops it
*
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
index 0757043f0..40704d4e9 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -18,10 +19,19 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
-import lombok.*;
+import lombok.AccessLevel;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
+import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
@@ -59,15 +69,15 @@ public abstract class QLearning getLegacyMDPWrapper();
- public QLearning(QLConfiguration conf) {
+ public QLearning(QLearningConfiguration conf) {
this(conf, getSeededRandom(conf.getSeed()));
}
- public QLearning(QLConfiguration conf, Random random) {
+ public QLearning(QLearningConfiguration conf, Random random) {
expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), random);
}
- private static Random getSeededRandom(Integer seed) {
+ private static Random getSeededRandom(Long seed) {
Random rnd = Nd4j.getRandom();
if(seed != null) {
rnd.setSeed(seed);
@@ -95,7 +105,7 @@ public abstract class QLearning scores;
- float epsilon;
+ double epsilon;
double startQ;
double meanQ;
}
@@ -213,12 +223,14 @@ public abstract class QLearning
* DQN or Deep Q-Learning in the Discrete domain
- *
+ *
* http://arxiv.org/abs/1312.5602
- *
*/
 public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
@Getter
- final private QLConfiguration configuration;
+ final private QLearningConfiguration configuration;
private final LegacyMDPWrapper mdp;
@Getter
private DQNPolicy policy;
@@ -78,16 +79,15 @@ public abstract class QLearningDiscrete extends QLearning mdp, IDQN dqn, QLConfiguration conf,
- int epsilonNbStep) {
+ public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf, int epsilonNbStep) {
this(mdp, dqn, conf, epsilonNbStep, Nd4j.getRandomFactory().getNewRandomInstance(conf.getSeed()));
}
- public QLearningDiscrete(MDP mdp, IDQN dqn, QLConfiguration conf,
+ public QLearningDiscrete(MDP mdp, IDQN dqn, QLearningConfiguration conf,
int epsilonNbStep, Random random) {
super(conf);
this.configuration = conf;
- this.mdp = new LegacyMDPWrapper(mdp, null, this);
+ this.mdp = new LegacyMDPWrapper<>(mdp, null, this);
qNetwork = dqn;
targetQNetwork = dqn.clone();
policy = new DQNPolicy(getQNetwork());
@@ -125,6 +125,7 @@ public abstract class QLearningDiscrete extends QLearning
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java
@@ ... @@ public class QLearningDiscreteConv extends QLearningDiscrete
@Deprecated
public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
- QLConfiguration conf, IDataManager dataManager) {
+ QLConfiguration conf, IDataManager dataManager) {
this(mdp, dqn, hpconf, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
+
+ @Deprecated
public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
QLConfiguration conf) {
+ super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
+ setHistoryProcessor(hpconf);
+ }
+
+ public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
+ QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
setHistoryProcessor(hpconf);
}
@Deprecated
public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
- HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
+ HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
}
+
+ @Deprecated
public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
}
+ public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
+ HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
+ this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
+ }
+
@Deprecated
public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
- HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
- this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager);
+ HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
+ this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
}
+
+ @Deprecated
public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
+ this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
+ }
+
+ public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf,
+ HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
}
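
The deprecated convolutional constructors above funnel legacy settings through the two bridge methods introduced in this patch, so existing call sites can migrate by doing that conversion once themselves. A minimal sketch, assuming only the types and conversion methods shown in the diff; the class name and all numeric values are illustrative placeholders, not recommendations.

import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
import org.nd4j.linalg.learning.config.Adam;

public class ConvConfigMigrationSketch {
    public static void main(String[] args) {
        // Legacy builders still compile, but are now @Deprecated.
        DQNFactoryStdConv.Configuration legacyNet = DQNFactoryStdConv.Configuration.builder()
                .learningRate(0.00025)
                .updater(new Adam(0.00025))
                .build();
        QLearning.QLConfiguration legacyLearning = QLearning.QLConfiguration.builder()
                .maxStep(500_000)
                .build();

        // The same conversions the deprecated QLearningDiscreteConv constructors perform internally.
        NetworkConfiguration netConf = legacyNet.toNetworkConfiguration();
        QLearningConfiguration learnConf = legacyLearning.toLearningConfiguration();

        System.out.println(netConf + " / " + learnConf);
    }
}
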
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
index ef69ea6fb..5b95cc84e 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,8 +17,10 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
@@ -38,7 +41,13 @@ public class QLearningDiscreteDense extends QLearningDiscre
this(mdp, dqn, conf);
addListener(new DataManagerTrainingListener(dataManager));
}
+
+ @Deprecated
public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) {
+ super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
+ }
+
+ public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) {
super(mdp, dqn, conf, conf.getEpsilonNbStep());
}
@@ -48,18 +57,33 @@ public class QLearningDiscreteDense extends QLearningDiscre
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf,
dataManager);
}
+
+ @Deprecated
public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
QLearning.QLConfiguration conf) {
this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
}
+ public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
+ QLearningConfiguration conf) {
+ this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
+ }
+
@Deprecated
public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf, IDataManager dataManager) {
- this(mdp, new DQNFactoryStdDense(netConf), conf, dataManager);
+
+ this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
}
+
+ @Deprecated
public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
QLearning.QLConfiguration conf) {
+ this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
+ }
+
+ public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf,
+ QLearningConfiguration conf) {
this(mdp, new DQNFactoryStdDense(netConf), conf);
}
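
For the dense case, the new overload takes the network and learning configuration objects directly. A usage sketch under the assumption that QLearningConfiguration's builder exposes fields mirroring the getters this patch calls (seed, maxStep, expRepMaxSize, batchSize, gamma, epsilonNbStep); the helper class and all hyperparameter values are illustrative only.

import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteDense;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.learning.config.Adam;

public class DenseQLearningSketch {

    // Works for any discrete-action MDP supplied by the caller.
    public static <O extends Encodable> QLearningDiscreteDense<O> build(MDP<O, Integer, DiscreteSpace> mdp) {
        DQNDenseNetworkConfiguration netConf = DQNDenseNetworkConfiguration.builder()
                .numLayers(3)            // same values as the @Builder.Default declarations later in this patch
                .numHiddenNodes(100)
                .updater(new Adam(1e-3))
                .build();

        QLearningConfiguration learnConf = QLearningConfiguration.builder()
                .seed(123L)              // assumed builder field names; they mirror getSeed(), getMaxStep(), ...
                .maxStep(150_000)
                .expRepMaxSize(150_000)
                .batchSize(32)
                .gamma(0.99)
                .epsilonNbStep(1_000)
                .build();

        return new QLearningDiscreteDense<>(mdp, netConf, learnConf);
    }
}
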
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java
index 274606ed9..63438bb74 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -36,7 +37,7 @@ import java.util.Collection;
*
* Standard implementation of ActorCriticCompGraph
*/
-public class ActorCriticCompGraph<NN extends ActorCriticCompGraph> implements IActorCritic<NN> {
+public class ActorCriticCompGraph implements IActorCritic {
final protected ComputationGraph cg;
@Getter
@@ -73,13 +74,13 @@ public class ActorCriticCompGraph implements IA
}
}
- public NN clone() {
- NN nn = (NN)new ActorCriticCompGraph(cg.clone());
+ public ActorCriticCompGraph clone() {
+ ActorCriticCompGraph nn = new ActorCriticCompGraph(cg.clone());
nn.cg.setListeners(cg.getListeners());
return nn;
}
- public void copy(NN from) {
+ public void copy(ActorCriticCompGraph from) {
cg.setParams(from.cg.params());
}
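
Dropping the NN type parameter makes duplicating the graph straightforward for callers. A two-line fragment, assuming actorCritic is an existing ActorCriticCompGraph in scope:

ActorCriticCompGraph target = actorCritic.clone(); // clone() now returns the concrete type, no cast needed
target.copy(actorCritic);                          // copies the source network's parameters into the clone
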
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java
index bdadd2969..eaccf2a10 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -31,12 +32,16 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration.ActorCriticNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
+import java.util.Arrays;
+
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
*
@@ -45,8 +50,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCompGraph {
-
- Configuration conf;
+ ActorCriticNetworkConfiguration conf;
public ActorCriticCompGraph buildActorCritic(int shapeInputs[], int numOutputs) {
@@ -109,16 +113,33 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom
return new ActorCriticCompGraph(model);
}
-
@AllArgsConstructor
@Builder
@Value
+ @Deprecated
public static class Configuration {
double l2;
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
+
+ /**
+ * Converts the deprecated Configuration to the new NetworkConfiguration format
+ */
+ public ActorCriticNetworkConfiguration toNetworkConfiguration() {
+ ActorCriticNetworkConfigurationBuilder builder = ActorCriticNetworkConfiguration.builder()
+ .l2(l2)
+ .updater(updater)
+ .useLSTM(useLSTM);
+
+ if (listeners != null) {
+ builder.listeners(Arrays.asList(listeners));
+ }
+
+ return builder.build();
+
+ }
}
}
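
With the factory now typed against ActorCriticNetworkConfiguration, new code can skip the deprecated nested class entirely. A minimal sketch using only builder fields that appear in this patch; the class name and values are placeholders.

import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration;
import org.nd4j.linalg.learning.config.Adam;

public class ConvActorCriticFactorySketch {
    public static ActorCriticFactoryCompGraphStdConv create() {
        ActorCriticNetworkConfiguration netConf = ActorCriticNetworkConfiguration.builder()
                .l2(0.0001)
                .updater(new Adam(0.00025))
                .useLSTM(false)
                .build();
        // @Value on the factory generates the all-args constructor used here.
        return new ActorCriticFactoryCompGraphStdConv(netConf);
    }
}
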
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java
index 7c9e3e21b..0d9dae3c6 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,8 +17,6 @@
package org.deeplearning4j.rl4j.network.ac;
-import lombok.AllArgsConstructor;
-import lombok.Builder;
import lombok.Value;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@@ -29,12 +28,11 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
-import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
-import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
@@ -45,7 +43,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCompGraph {
- Configuration conf;
+ ActorCriticDenseNetworkConfiguration conf;
public ActorCriticCompGraph buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1;
@@ -65,27 +63,27 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
"input");
- for (int i = 1; i < conf.getNumLayer(); i++) {
+ for (int i = 1; i < conf.getNumLayers(); i++) {
confB.addLayer(i + "", new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
.activation(Activation.RELU).build(), (i - 1) + "");
}
if (conf.isUseLSTM()) {
- confB.addLayer(getConf().getNumLayer() + "", new LSTM.Builder().activation(Activation.TANH)
- .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayer() - 1) + "");
+ confB.addLayer(getConf().getNumLayers() + "", new LSTM.Builder().activation(Activation.TANH)
+ .nOut(conf.getNumHiddenNodes()).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("value", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
- .nOut(1).build(), getConf().getNumLayer() + "");
+ .nOut(1).build(), getConf().getNumLayers() + "");
confB.addLayer("softmax", new RnnOutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
- .nOut(numOutputs).build(), getConf().getNumLayer() + "");
+ .nOut(numOutputs).build(), getConf().getNumLayers() + "");
} else {
confB.addLayer("value", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
- .nOut(1).build(), (getConf().getNumLayer() - 1) + "");
+ .nOut(1).build(), (getConf().getNumLayers() - 1) + "");
confB.addLayer("softmax", new OutputLayer.Builder(new ActorCriticLoss()).activation(Activation.SOFTMAX)
- .nOut(numOutputs).build(), (getConf().getNumLayer() - 1) + "");
+ .nOut(numOutputs).build(), (getConf().getNumLayers() - 1) + "");
}
confB.setOutputs("value", "softmax");
@@ -103,18 +101,4 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo
return new ActorCriticCompGraph(model);
}
- @AllArgsConstructor
- @Builder
- @Value
- public static class Configuration {
-
- int numLayer;
- int numHiddenNodes;
- double l2;
- IUpdater updater;
- TrainingListener[] listeners;
- boolean useLSTM;
- }
-
-
}
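
Because this factory's nested Configuration class is removed outright rather than deprecated, callers must switch to ActorCriticDenseNetworkConfiguration. A migration sketch; numLayers, numHiddenNodes, useLSTM and updater are the builder fields this patch itself uses, everything else is a placeholder.

import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdDense;
import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.nd4j.linalg.learning.config.Adam;

public class DenseActorCriticFactorySketch {
    public static ActorCriticFactoryCompGraphStdDense create() {
        ActorCriticDenseNetworkConfiguration netConf = ActorCriticDenseNetworkConfiguration.builder()
                .numLayers(3)            // note: the removed nested Configuration called this numLayer
                .numHiddenNodes(16)
                .useLSTM(false)
                .updater(new Adam(1e-3))
                .build();
        return new ActorCriticFactoryCompGraphStdDense(netConf);
    }
}
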
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java
index a55e351c0..4ac557096 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -31,21 +32,24 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration.ActorCriticDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
+import java.util.Arrays;
+
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
- *
- *
*/
@Value
public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySeparate {
- Configuration conf;
+ ActorCriticDenseNetworkConfiguration conf;
public ActorCriticSeparate buildActorCritic(int[] numInputs, int numOutputs) {
int nIn = 1;
@@ -53,27 +57,27 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
nIn *= i;
}
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
- .weightInit(WeightInit.XAVIER)
- .l2(conf.getL2())
- .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
+ .weightInit(WeightInit.XAVIER)
+ .l2(conf.getL2())
+ .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
+ .activation(Activation.RELU).build());
- for (int i = 1; i < conf.getNumLayer(); i++) {
+ for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .activation(Activation.RELU).build());
}
if (conf.isUseLSTM()) {
- confB.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
+ confB.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
- confB.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
- .nIn(conf.getNumHiddenNodes()).nOut(1).build());
+ confB.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
+ .nIn(conf.getNumHiddenNodes()).nOut(1).build());
} else {
- confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
- .nIn(conf.getNumHiddenNodes()).nOut(1).build());
+ confB.layer(conf.getNumLayers(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
+ .nIn(conf.getNumHiddenNodes()).nOut(1).build());
}
confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
@@ -87,28 +91,28 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
}
NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
- .weightInit(WeightInit.XAVIER)
- //.regularization(true)
- //.l2(conf.getL2())
- .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
+ .weightInit(WeightInit.XAVIER)
+ //.regularization(true)
+ //.l2(conf.getL2())
+ .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
+ .activation(Activation.RELU).build());
- for (int i = 1; i < conf.getNumLayer(); i++) {
+ for (int i = 1; i < conf.getNumLayers(); i++) {
confB2.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .activation(Activation.RELU).build());
}
if (conf.isUseLSTM()) {
- confB2.layer(conf.getNumLayer(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
+ confB2.layer(conf.getNumLayers(), new LSTM.Builder().nOut(conf.getNumHiddenNodes()).activation(Activation.TANH).build());
- confB2.layer(conf.getNumLayer() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
- .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
+ confB2.layer(conf.getNumLayers() + 1, new RnnOutputLayer.Builder(new ActorCriticLoss())
+ .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
} else {
- confB2.layer(conf.getNumLayer(), new OutputLayer.Builder(new ActorCriticLoss())
- .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
+ confB2.layer(conf.getNumLayers(), new OutputLayer.Builder(new ActorCriticLoss())
+ .activation(Activation.SOFTMAX).nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
}
confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn));
@@ -128,6 +132,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
@AllArgsConstructor
@Value
@Builder
+ @Deprecated
public static class Configuration {
int numLayer;
@@ -136,6 +141,22 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep
IUpdater updater;
TrainingListener[] listeners;
boolean useLSTM;
+
+ public ActorCriticDenseNetworkConfiguration toNetworkConfiguration() {
+ ActorCriticDenseNetworkConfigurationBuilder builder = ActorCriticDenseNetworkConfiguration.builder()
+ .numHiddenNodes(numHiddenNodes)
+ .numLayers(numLayer)
+ .l2(l2)
+ .updater(updater)
+ .useLSTM(useLSTM);
+
+ if (listeners != null) {
+ builder.listeners(Arrays.asList(listeners));
+ }
+
+ return builder.build();
+
+ }
}
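
Here the deprecated Configuration survives with a bridge, so existing call sites can migrate one line at a time. A short fragment (imports as elsewhere in this patch; field values are placeholders):

ActorCriticFactorySeparateStdDense.Configuration legacy =
        ActorCriticFactorySeparateStdDense.Configuration.builder()
                .numLayer(3)
                .numHiddenNodes(16)
                .build();

// toNetworkConfiguration() is the bridge added in this patch; the factory itself now wants the new type.
ActorCriticFactorySeparateStdDense factory =
        new ActorCriticFactorySeparateStdDense(legacy.toNetworkConfiguration());
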
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java
new file mode 100644
index 000000000..e85ec6356
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java
@@ -0,0 +1,42 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.network.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@EqualsAndHashCode(callSuper = true)
+public class ActorCriticDenseNetworkConfiguration extends ActorCriticNetworkConfiguration {
+
+ /**
+ * The number of layers in the dense network
+ */
+ @Builder.Default
+ private int numLayers = 3;
+
+ /**
+ * The number of hidden neurons in each layer
+ */
+ @Builder.Default
+ private int numHiddenNodes = 100;
+
+
+}
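
Since both fields carry @Builder.Default, an empty builder already yields the documented defaults. A tiny check, using the getters generated by @Data:

ActorCriticDenseNetworkConfiguration defaults = ActorCriticDenseNetworkConfiguration.builder().build();
assert defaults.getNumLayers() == 3;        // @Builder.Default
assert defaults.getNumHiddenNodes() == 100; // @Builder.Default
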
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java
new file mode 100644
index 000000000..c043f458e
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java
@@ -0,0 +1,37 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.network.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+@EqualsAndHashCode(callSuper = true)
+public class ActorCriticNetworkConfiguration extends NetworkConfiguration {
+
+ /**
+ * Whether or not to add an LSTM layer to the network.
+ */
+ @Builder.Default
+ private boolean useLSTM = false;
+
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java
new file mode 100644
index 000000000..452cb83c2
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.network.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.experimental.SuperBuilder;
+
+@Data
+@SuperBuilder
+@EqualsAndHashCode(callSuper = true)
+public class DQNDenseNetworkConfiguration extends NetworkConfiguration {
+
+ /**
+ * The number of layers in the dense network
+ */
+ @Builder.Default
+ private int numLayers = 3;
+
+ /**
+ * The number of hidden neurons in each layer
+ */
+ @Builder.Default
+ private int numHiddenNodes = 100;
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java
new file mode 100644
index 000000000..c77c379a2
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java
@@ -0,0 +1,58 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.network.configuration;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.Singular;
+import lombok.experimental.SuperBuilder;
+import org.deeplearning4j.optimize.api.TrainingListener;
+import org.nd4j.linalg.learning.config.IUpdater;
+
+import java.util.List;
+
+
+@Data
+@SuperBuilder
+@NoArgsConstructor
+public class NetworkConfiguration {
+
+ /**
+ * The learning rate of the network
+ */
+ @Builder.Default
+ private double learningRate = 0.01;
+
+ /**
+ * L2 regularization on the network
+ */
+ @Builder.Default
+ private double l2 = 0.0;
+
+ /**
+ * The network's gradient update algorithm
+ */
+ private IUpdater updater;
+
+ /**
+ * Training listeners attached to the network
+ */
+ @Singular
+ private List<TrainingListener> listeners;
+
+}
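
The @Singular annotation means the generated builder accumulates listeners one at a time through a singular listener(...) method, in addition to accepting a whole collection. A short fragment with a standard DL4J listener (imports as already used elsewhere in this patch; values are placeholders):

NetworkConfiguration conf = NetworkConfiguration.builder()
        .learningRate(0.00025)
        .updater(new Adam(0.00025))
        .listener(new ScoreIterationListener(100)) // singular adder generated by @Singular
        .build();
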
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java
index ec09d1c1c..077bbf1ce 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -30,12 +31,15 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
+import java.util.Arrays;
+
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/
@@ -43,7 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
public class DQNFactoryStdConv implements DQNFactory {
- Configuration conf;
+ NetworkConfiguration conf;
public DQN buildDQN(int shapeInputs[], int numOutputs) {
@@ -80,7 +84,6 @@ public class DQNFactoryStdConv implements DQNFactory {
return new DQN(model);
}
-
@AllArgsConstructor
@Builder
@Value
@@ -90,6 +93,23 @@ public class DQNFactoryStdConv implements DQNFactory {
double l2;
IUpdater updater;
TrainingListener[] listeners;
+
+ /**
+ * Converts the deprecated Configuration to the new NetworkConfiguration format
+ */
+ public NetworkConfiguration toNetworkConfiguration() {
+ NetworkConfiguration.NetworkConfigurationBuilder builder = NetworkConfiguration.builder()
+ .learningRate(learningRate)
+ .l2(l2)
+ .updater(updater);
+
+ if (listeners != null) {
+ builder.listeners(Arrays.asList(listeners));
+ }
+
+ return builder.build();
+
+ }
}
}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java
index 323ca7ecb..ebe730b4d 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -28,12 +29,16 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
+import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
+import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration.DQNDenseNetworkConfigurationBuilder;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
+import java.util.Arrays;
+
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/13/16.
*/
@@ -41,32 +46,41 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Value
public class DQNFactoryStdDense implements DQNFactory {
-
- Configuration conf;
+ DQNDenseNetworkConfiguration conf;
public DQN buildDQN(int[] numInputs, int numOutputs) {
int nIn = 1;
+
for (int i : numInputs) {
nIn *= i;
}
+
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- //.updater(Updater.NESTEROVS).momentum(0.9)
- //.updater(Updater.RMSPROP).rho(conf.getRmsDecay())//.rmsDecay(conf.getRmsDecay())
- .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
- .weightInit(WeightInit.XAVIER)
- .l2(conf.getL2())
- .list().layer(0, new DenseLayer.Builder().nIn(nIn).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam())
+ .weightInit(WeightInit.XAVIER)
+ .l2(conf.getL2())
+ .list()
+ .layer(0,
+ new DenseLayer.Builder()
+ .nIn(nIn)
+ .nOut(conf.getNumHiddenNodes())
+ .activation(Activation.RELU).build()
+ );
- for (int i = 1; i < conf.getNumLayer(); i++) {
+ for (int i = 1; i < conf.getNumLayers(); i++) {
confB.layer(i, new DenseLayer.Builder().nIn(conf.getNumHiddenNodes()).nOut(conf.getNumHiddenNodes())
- .activation(Activation.RELU).build());
+ .activation(Activation.RELU).build());
}
- confB.layer(conf.getNumLayer(), new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
- .nIn(conf.getNumHiddenNodes()).nOut(numOutputs).build());
+ confB.layer(conf.getNumLayers(),
+ new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
+ .activation(Activation.IDENTITY)
+ .nIn(conf.getNumHiddenNodes())
+ .nOut(numOutputs)
+ .build()
+ );
MultiLayerConfiguration mlnconf = confB.build();
@@ -83,6 +97,7 @@ public class DQNFactoryStdDense implements DQNFactory {
@AllArgsConstructor
@Value
@Builder
+ @Deprecated
public static class Configuration {
int numLayer;
@@ -90,7 +105,23 @@ public class DQNFactoryStdDense implements DQNFactory {
double l2;
IUpdater updater;
TrainingListener[] listeners;
+
+ /**
+ * Converts the deprecated Configuration to the new NetworkConfiguration format
+ */
+ public DQNDenseNetworkConfiguration toNetworkConfiguration() {
+ DQNDenseNetworkConfigurationBuilder builder = DQNDenseNetworkConfiguration.builder()
+ .numHiddenNodes(numHiddenNodes)
+ .numLayers(numLayer)
+ .l2(l2)
+ .updater(updater);
+
+ if (listeners != null) {
+ builder.listeners(Arrays.asList(listeners));
+ }
+
+ return builder.build();
+ }
}
-
}
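
The same bridge pattern applies to the dense DQN factory. A fragment for call sites that already hold a deprecated Configuration (imports as elsewhere in this patch; values are placeholders):

DQNFactoryStdDense.Configuration legacy = DQNFactoryStdDense.Configuration.builder()
        .numLayer(3)
        .numHiddenNodes(16)
        .updater(new Adam(1e-3))
        .build();

// Equivalent new-style configuration, produced by the bridge added above.
DQNDenseNetworkConfiguration netConf = legacy.toNetworkConfiguration();
DQNFactoryStdDense factory = new DQNFactoryStdDense(netConf);
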
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
index 3ed375084..3454a37e6 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -46,7 +47,7 @@ public class EpsGreedy> extends Policy {
final private int updateStart;
final private int epsilonNbStep;
final private Random rnd;
- final private float minEpsilon;
+ final private double minEpsilon;
final private IEpochTrainer learning;
public NeuralNet getNeuralNet() {
@@ -55,10 +56,10 @@ public class EpsGreedy> extends Policy {
public A nextAction(INDArray input) {
- float ep = getEpsilon();
+ double ep = getEpsilon();
if (learning.getStepCounter() % 500 == 1)
log.info("EP: " + ep + " " + learning.getStepCounter());
- if (rnd.nextFloat() > ep)
+ if (rnd.nextDouble() > ep)
return policy.nextAction(input);
else
return mdp.getActionSpace().randomAction();
@@ -68,7 +69,7 @@ public class EpsGreedy> extends Policy {
return this.nextAction(observation.getData());
}
- public float getEpsilon() {
- return Math.min(1f, Math.max(minEpsilon, 1f - (learning.getStepCounter() - updateStart) * 1f / epsilonNbStep));
+ public double getEpsilon() {
+ return Math.min(1.0, Math.max(minEpsilon, 1.0 - (learning.getStepCounter() - updateStart) * 1.0 / epsilonNbStep));
}
}
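
The getEpsilon() change above is a pure float-to-double widening; the schedule still anneals linearly from 1.0 down to minEpsilon over epsilonNbStep steps once updateStart is passed. A worked check of the formula with arbitrary values:

double minEpsilon = 0.1;
int updateStart = 0;
int epsilonNbStep = 1000;
int stepCounter = 500;

double epsilon = Math.min(1.0, Math.max(minEpsilon,
        1.0 - (stepCounter - updateStart) * 1.0 / epsilonNbStep));
// epsilon == 0.5 at step 500; from step 900 onward the schedule bottoms out at minEpsilon (0.1).
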
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
index b639efdaa..bffafdb76 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -22,17 +23,30 @@ import lombok.Builder;
import lombok.Getter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
-import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
-import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.rl4j.learning.ILearning;
-import org.deeplearning4j.rl4j.learning.Learning;
+import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
+import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.util.ModelSerializer;
+import org.nd4j.linalg.primitives.Pair;
-import java.io.*;
-import java.nio.file.*;
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardCopyOption;
+import java.nio.file.StandardOpenOption;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
@@ -304,7 +318,7 @@ public class DataManager implements IDataManager {
public static class Info {
String trainingName;
String mdpName;
- ILearning.LConfiguration conf;
+ ILearningConfiguration conf;
int stepCounter;
long millisTime;
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java
index 26ec0708f..8718d252d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,14 +17,11 @@
package org.deeplearning4j.rl4j.learning;
-import java.util.Arrays;
import org.junit.Test;
-import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
/**
*
@@ -32,7 +30,7 @@ import static org.junit.Assert.assertTrue;
public class HistoryProcessorTest {
@Test
- public void testHistoryProcessor() throws Exception {
+ public void testHistoryProcessor() {
HistoryProcessor.Configuration conf = HistoryProcessor.Configuration.builder()
.croppingHeight(2).croppingWidth(2).rescaledHeight(2).rescaledWidth(2).build();
IHistoryProcessor hp = new HistoryProcessor(conf);
@@ -43,8 +41,6 @@ public class HistoryProcessorTest {
hp.add(a);
INDArray[] h = hp.getHistory();
assertEquals(4, h.length);
-// System.out.println(Arrays.toString(a.shape()));
-// System.out.println(Arrays.toString(h[0].shape()));
assertEquals( 1, h[0].shape()[0]);
assertEquals(a.shape()[0], h[0].shape()[1]);
assertEquals(a.shape()[1], h[0].shape()[2]);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
index 2302117d2..f2941feef 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java
@@ -1,9 +1,32 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.async;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.*;
+import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
+import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
+import org.deeplearning4j.rl4j.support.MockEncodable;
+import org.deeplearning4j.rl4j.support.MockNeuralNet;
+import org.deeplearning4j.rl4j.support.MockPolicy;
+import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -68,7 +91,7 @@ public class AsyncLearningTest {
public static class TestContext {
- MockAsyncConfiguration config = new MockAsyncConfiguration(1, 11, 0, 0, 0, 0,0, 0, 0, 0);
+ MockAsyncConfiguration config = new MockAsyncConfiguration(1L, 11, 0, 0, 0, 0,0, 0, 0, 0);
public final MockAsyncGlobal asyncGlobal = new MockAsyncGlobal();
public final MockPolicy policy = new MockPolicy();
public final TestAsyncLearning sut = new TestAsyncLearning(config, asyncGlobal, policy);
@@ -82,11 +105,11 @@ public class AsyncLearningTest {
}
public static class TestAsyncLearning extends AsyncLearning {
- private final AsyncConfiguration conf;
+ private final IAsyncLearningConfiguration conf;
private final IAsyncGlobal asyncGlobal;
private final IPolicy policy;
- public TestAsyncLearning(AsyncConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) {
+ public TestAsyncLearning(IAsyncLearningConfiguration conf, IAsyncGlobal asyncGlobal, IPolicy policy) {
this.conf = conf;
this.asyncGlobal = asyncGlobal;
this.policy = policy;
@@ -98,7 +121,7 @@ public class AsyncLearningTest {
}
@Override
- public AsyncConfiguration getConfiguration() {
+ public IAsyncLearningConfiguration getConfiguration() {
return conf;
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
index bc396502f..72f374db5 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@ -1,7 +1,25 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
@@ -32,7 +50,7 @@ public class AsyncThreadDiscreteTest {
MockMDP mdpMock = new MockMDP(observationSpace);
TrainingListenerList listeners = new TrainingListenerList();
MockPolicy policyMock = new MockPolicy();
- MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
+ MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 100, 0,0, 0, 0, 0, 0, 2, 5);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
@@ -173,7 +191,7 @@ public class AsyncThreadDiscreteTest {
}
@Override
- protected AsyncConfiguration getConf() {
+ protected IAsyncLearningConfiguration getConf() {
return config;
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
index 3dea25936..ff29960f1 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
@@ -3,12 +3,20 @@ package org.deeplearning4j.rl4j.learning.async;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.support.*;
+import org.deeplearning4j.rl4j.support.MockAsyncConfiguration;
+import org.deeplearning4j.rl4j.support.MockAsyncGlobal;
+import org.deeplearning4j.rl4j.support.MockEncodable;
+import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
+import org.deeplearning4j.rl4j.support.MockMDP;
+import org.deeplearning4j.rl4j.support.MockNeuralNet;
+import org.deeplearning4j.rl4j.support.MockObservationSpace;
+import org.deeplearning4j.rl4j.support.MockTrainingListener;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.junit.Test;
@@ -16,7 +24,6 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
public class AsyncThreadTest {
@@ -126,7 +133,7 @@ public class AsyncThreadTest {
public final MockNeuralNet neuralNet = new MockNeuralNet();
public final MockObservationSpace observationSpace = new MockObservationSpace();
public final MockMDP mdp = new MockMDP(observationSpace);
- public final MockAsyncConfiguration config = new MockAsyncConfiguration(5, 10, 0, 0, 10, 0, 0, 0, 0, 0);
+ public final MockAsyncConfiguration config = new MockAsyncConfiguration(5L, 10, 0, 0, 0, 0, 0, 0, 10, 0);
public final TrainingListenerList listeners = new TrainingListenerList();
public final MockTrainingListener listener = new MockTrainingListener();
public final IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
@@ -149,11 +156,11 @@ public class AsyncThreadTest {
private final MockAsyncGlobal asyncGlobal;
private final MockNeuralNet neuralNet;
- private final AsyncConfiguration conf;
+ private final IAsyncLearningConfiguration conf;
private final List trainSubEpochParams = new ArrayList();
- public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, TrainingListenerList listeners) {
+ public MockAsyncThread(MockAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, IAsyncLearningConfiguration conf, TrainingListenerList listeners) {
super(asyncGlobal, mdp, listeners, threadNumber, 0);
this.asyncGlobal = asyncGlobal;
@@ -184,7 +191,7 @@ public class AsyncThreadTest {
}
@Override
- protected AsyncConfiguration getConf() {
+ protected IAsyncLearningConfiguration getConf() {
return conf;
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java
index ef7fec7d0..b812a5582 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscreteTest.java
@@ -1,11 +1,27 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.async.a3c.discrete;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
-import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningDiscrete;
-import org.deeplearning4j.rl4j.learning.async.nstep.discrete.AsyncNStepQLearningThreadDiscrete;
+import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.support.*;
@@ -31,7 +47,7 @@ public class A3CThreadDiscreteTest {
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
- A3CDiscrete.A3CConfiguration config = new A3CDiscrete.A3CConfiguration(0, 0, 0, 0, 0, 0, 0, gamma, 0);
+ A3CLearningConfiguration config = A3CLearningConfiguration.builder().gamma(0.9).build();
MockActorCritic actorCriticMock = new MockActorCritic();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(actorCriticMock);
@@ -54,9 +70,9 @@ public class A3CThreadDiscreteTest {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
- minitransList.push(new MiniTrans(obs, i, output, rewards[i]));
+ minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
- minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans
+ minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(actorCriticMock, minitransList);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java
index d105419df..2a8c5b832 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscreteTest.java
@@ -1,7 +1,24 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2020 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.async.nstep.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
+import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration;
import org.deeplearning4j.rl4j.support.*;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -19,7 +36,7 @@ public class AsyncNStepQLearningThreadDiscreteTest {
double gamma = 0.9;
MockObservationSpace observationSpace = new MockObservationSpace();
MockMDP mdpMock = new MockMDP(observationSpace);
- AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration config = new AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration(0, 0, 0, 0, 0, 0, 0, 0, gamma, 0, 0, 0);
+ AsyncQLearningConfiguration config = AsyncQLearningConfiguration.builder().gamma(gamma).build();
MockDQN dqnMock = new MockDQN();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 1, 1, 1, 1, 0, 0, 2);
MockAsyncGlobal asyncGlobalMock = new MockAsyncGlobal(dqnMock);
@@ -42,9 +59,9 @@ public class AsyncNStepQLearningThreadDiscreteTest {
Nd4j.zeros(5)
};
output[0].putScalar(i, outputs[i]);
- minitransList.push(new MiniTrans(obs, i, output, rewards[i]));
+ minitransList.push(new MiniTrans<>(obs, i, output, rewards[i]));
}
- minitransList.push(new MiniTrans(null, 0, null, 4.0)); // The special batch-ending MiniTrans
+ minitransList.push(new MiniTrans<>(null, 0, null, 4.0)); // The special batch-ending MiniTrans
// Act
sut.calcGradient(dqnMock, minitransList);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
index 79be025b5..22e4be3f6 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java
@@ -1,6 +1,26 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.sync;
import lombok.Getter;
+import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.LearningConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@@ -17,7 +37,7 @@ public class SyncLearningTest {
@Test
public void when_training_expect_listenersToBeCalled() {
// Arrange
- QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
+ QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@@ -34,7 +54,7 @@ public class SyncLearningTest {
@Test
public void when_trainingStartCanContinueFalse_expect_trainingStopped() {
// Arrange
- QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
+ QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@@ -52,7 +72,7 @@ public class SyncLearningTest {
@Test
public void when_newEpochCanContinueFalse_expect_trainingStopped() {
// Arrange
- QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
+ QLearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@@ -70,7 +90,7 @@ public class SyncLearningTest {
@Test
public void when_epochTrainingResultCanContinueFalse_expect_trainingStopped() {
// Arrange
- QLearning.QLConfiguration lconfig = QLearning.QLConfiguration.builder().maxStep(10).build();
+ LearningConfiguration lconfig = QLearningConfiguration.builder().maxStep(10).build();
MockTrainingListener listener = new MockTrainingListener();
MockSyncLearning sut = new MockSyncLearning(lconfig);
sut.addListener(listener);
@@ -87,12 +107,12 @@ public class SyncLearningTest {
public static class MockSyncLearning extends SyncLearning {
- private final LConfiguration conf;
+ private final ILearningConfiguration conf;
@Getter
private int currentEpochStep = 0;
- public MockSyncLearning(LConfiguration conf) {
+ public MockSyncLearning(ILearningConfiguration conf) {
this.conf = conf;
}
@@ -119,7 +139,7 @@ public class SyncLearningTest {
}
@Override
- public LConfiguration getConfiguration() {
+ public ILearningConfiguration getConfiguration() {
return conf;
}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java
similarity index 52%
rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java
rename to rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java
index b12866ed2..d7d9bf072 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLConfigurationTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -17,36 +18,24 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import com.fasterxml.jackson.databind.ObjectMapper;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
-public class QLConfigurationTest {
+public class QLearningConfigurationTest {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void serialize() throws Exception {
ObjectMapper mapper = new ObjectMapper();
- QLearning.QLConfiguration qlConfiguration =
- new QLearning.QLConfiguration(
- 123, //Random seed
- 200, //Max step By epoch
- 8000, //Max step
- 150000, //Max size of experience replay
- 32, //size of batches
- 500, //target update (hard)
- 10, //num step noop warmup
- 0.01, //reward scaling
- 0.99, //gamma
- 1.0, //td error clipping
- 0.1f, //min epsilon
- 10000, //num step for eps greedy anneal
- true //double DQN
- );
+
+ QLearningConfiguration qLearningConfiguration = QLearningConfiguration.builder()
+ .build();
// Should not throw..
- String json = mapper.writeValueAsString(qlConfiguration);
- QLearning.QLConfiguration cnf = mapper.readValue(json, QLearning.QLConfiguration.class);
+ String json = mapper.writeValueAsString(qLearningConfiguration);
+ QLearningConfiguration cnf = mapper.readValue(json, QLearningConfiguration.class);
}
-}
\ No newline at end of file
+}
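The rewritten serialize() test leans on the builder's defaults instead of thirteen positional values and only checks that Jackson can round-trip the object. A slightly stronger variant that also verifies a non-default field survives the round trip is sketched below; it assumes the usual static import of org.junit.Assert.assertEquals and that Lombok generates a getGamma() accessor, neither of which this patch shows directly.

    // Sketch: round-trip one non-default value through Jackson.
    ObjectMapper mapper = new ObjectMapper();
    QLearningConfiguration original = QLearningConfiguration.builder()
            .gamma(0.99)
            .build();
    String json = mapper.writeValueAsString(original);
    QLearningConfiguration restored = mapper.readValue(json, QLearningConfiguration.class);
    assertEquals(0.99, restored.getGamma(), 1e-9);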
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index 58aaab297..fe8dd6acc 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -1,6 +1,24 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
@@ -27,7 +45,7 @@ public class QLearningDiscreteTest {
// Arrange
MockObservationSpace observationSpace = new MockObservationSpace();
MockDQN dqn = new MockDQN();
- MockRandom random = new MockRandom(new double[] {
+ MockRandom random = new MockRandom(new double[]{
0.7309677600860596,
0.8314409852027893,
0.2405363917350769,
@@ -36,14 +54,26 @@ public class QLearningDiscreteTest {
0.3090505599975586,
0.5504369735717773,
0.11700659990310669
- },
- new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
+ },
+ new int[]{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
MockMDP mdp = new MockMDP(observationSpace, random);
int initStepCount = 8;
- QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 24, 0, 5, 1, 1000,
- initStepCount, 1.0, 0, 0, 0, 0, true);
+ QLearningConfiguration conf = QLearningConfiguration.builder()
+ .seed(0L)
+ .maxEpochStep(24)
+ .maxStep(0)
+ .expRepMaxSize(5).batchSize(1).targetDqnUpdateFreq(1000)
+ .updateStart(initStepCount)
+ .rewardFactor(1.0)
+ .gamma(0)
+ .errorClamp(0)
+ .minEpsilon(0)
+ .epsilonNbStep(0)
+ .doubleDQN(true)
+ .build();
+
MockDataManager dataManager = new MockDataManager(false);
MockExpReplay expReplay = new MockExpReplay();
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, expReplay, 10, random);
@@ -58,9 +88,9 @@ public class QLearningDiscreteTest {
// Assert
// HistoryProcessor calls
- double[] expectedRecords = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0 };
+ double[] expectedRecords = new double[]{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
assertEquals(expectedRecords.length, hp.recordCalls.size());
- for(int i = 0; i < expectedRecords.length; ++i) {
+ for (int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
}
@@ -72,59 +102,59 @@ public class QLearningDiscreteTest {
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
assertEquals(14, dqn.outputParams.size());
- double[][] expectedDQNOutput = new double[][] {
- new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
- new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
- new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
- new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
- new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
- new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
- new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
- new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
- new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
- new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
- new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
- new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
+ double[][] expectedDQNOutput = new double[][]{
+ new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
+ new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
+ new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
+ new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
+ new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
+ new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
+ new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
+ new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
+ new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
+ new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
+ new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
+ new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
+ new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
+ new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
};
- for(int i = 0; i < expectedDQNOutput.length; ++i) {
+ for (int i = 0; i < expectedDQNOutput.length; ++i) {
INDArray outputParam = dqn.outputParams.get(i);
assertEquals(5, outputParam.shape()[1]);
assertEquals(1, outputParam.shape()[2]);
double[] expectedRow = expectedDQNOutput[i];
- for(int j = 0; j < expectedRow.length; ++j) {
- assertEquals("row: "+ i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
+ for (int j = 0; j < expectedRow.length; ++j) {
+ assertEquals("row: " + i + " col: " + j, expectedRow[j], 255.0 * outputParam.getDouble(j), 0.00001);
}
}
// MDP calls
- assertArrayEquals(new Integer[] {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
+ assertArrayEquals(new Integer[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, mdp.actions.toArray());
// ExpReplay calls
- double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0 };
- int[] expectedTrActions = new int[] { 1, 4, 2, 4, 4, 4, 4, 4 };
- double[] expectedTrNextObservation = new double[] { 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
- double[][] expectedTrObservations = new double[][] {
- new double[] { 0.0, 2.0, 4.0, 6.0, 8.0 },
- new double[] { 2.0, 4.0, 6.0, 8.0, 10.0 },
- new double[] { 4.0, 6.0, 8.0, 10.0, 12.0 },
- new double[] { 6.0, 8.0, 10.0, 12.0, 14.0 },
- new double[] { 8.0, 10.0, 12.0, 14.0, 16.0 },
- new double[] { 10.0, 12.0, 14.0, 16.0, 18.0 },
- new double[] { 12.0, 14.0, 16.0, 18.0, 20.0 },
- new double[] { 14.0, 16.0, 18.0, 20.0, 22.0 },
+ double[] expectedTrRewards = new double[]{9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0};
+ int[] expectedTrActions = new int[]{1, 4, 2, 4, 4, 4, 4, 4};
+ double[] expectedTrNextObservation = new double[]{2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0};
+ double[][] expectedTrObservations = new double[][]{
+ new double[]{0.0, 2.0, 4.0, 6.0, 8.0},
+ new double[]{2.0, 4.0, 6.0, 8.0, 10.0},
+ new double[]{4.0, 6.0, 8.0, 10.0, 12.0},
+ new double[]{6.0, 8.0, 10.0, 12.0, 14.0},
+ new double[]{8.0, 10.0, 12.0, 14.0, 16.0},
+ new double[]{10.0, 12.0, 14.0, 16.0, 18.0},
+ new double[]{12.0, 14.0, 16.0, 18.0, 20.0},
+ new double[]{14.0, 16.0, 18.0, 20.0, 22.0},
};
assertEquals(expectedTrObservations.length, expReplay.transitions.size());
- for(int i = 0; i < expectedTrRewards.length; ++i) {
+ for (int i = 0; i < expectedTrRewards.length; ++i) {
Transition tr = expReplay.transitions.get(i);
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
assertEquals(expectedTrActions[i], tr.getAction());
assertEquals(expectedTrNextObservation[i], 255.0 * tr.getNextObservation().getDouble(0), 0.0001);
- for(int j = 0; j < expectedTrObservations[i].length; ++j) {
- assertEquals("row: "+ i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
+ for (int j = 0; j < expectedTrObservations[i].length; ++j) {
+ assertEquals("row: " + i + " col: " + j, expectedTrObservations[i][j], 255.0 * tr.getObservation().getData().getDouble(0, j, 0), 0.0001);
}
}
@@ -132,12 +162,12 @@ public class QLearningDiscreteTest {
assertEquals(initStepCount + 16, result.getStepCounter());
assertEquals(300.0, result.getReward(), 0.00001);
assertTrue(dqn.hasBeenReset);
- assertTrue(((MockDQN)sut.getTargetQNetwork()).hasBeenReset);
+ assertTrue(((MockDQN) sut.getTargetQNetwork()).hasBeenReset);
}
public static class TestQLearningDiscrete extends QLearningDiscrete {
public TestQLearningDiscrete(MDP mdp, IDQN dqn,
- QLConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
+ QLearningConfiguration conf, IDataManager dataManager, MockExpReplay expReplay,
int epsilonNbStep, Random rnd) {
super(mdp, dqn, conf, epsilonNbStep, rnd);
addListener(new DataManagerTrainingListener(dataManager));
@@ -146,10 +176,10 @@ public class QLearningDiscreteTest {
@Override
protected DataSet setTarget(ArrayList> transitions) {
- return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
+ return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[]{123.0}), Nd4j.create(new double[]{234.0}));
}
- public void setExpReplay(IExpReplay exp){
+ public void setExpReplay(IExpReplay exp) {
this.expReplay = exp;
}
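For anyone migrating other call sites, the thirteen positional QLConfiguration arguments line up with the new builder methods as follows. Every builder method named here already appears in the hunks above; the parameter types are inferred from the literals used in this patch (e.g. seed(0L)), so treat them as a best guess where only zeros are shown.

    // Old positional argument             -> new builder method
    // 1.  random seed                     -> .seed(long)
    // 2.  max step by epoch               -> .maxEpochStep(int)
    // 3.  max step                        -> .maxStep(int)
    // 4.  max size of experience replay   -> .expRepMaxSize(int)
    // 5.  size of batches                 -> .batchSize(int)
    // 6.  target update (hard)            -> .targetDqnUpdateFreq(int)
    // 7.  num step noop warmup            -> .updateStart(int)
    // 8.  reward scaling                  -> .rewardFactor(double)
    // 9.  gamma                           -> .gamma(double)
    // 10. td error clipping               -> .errorClamp(double)
    // 11. min epsilon                     -> .minEpsilon(double)
    // 12. num step for eps greedy anneal  -> .epsilonNbStep(int)
    // 13. double DQN                      -> .doubleDQN(boolean)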
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java
index c43c26d50..821863054 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.ac;
+import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration;
import org.junit.Test;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -29,30 +31,31 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
- *
* @author saudet
*/
public class ActorCriticTest {
- public static ActorCriticFactorySeparateStdDense.Configuration NET_CONF =
- new ActorCriticFactorySeparateStdDense.Configuration(
- 4, //number of layers
- 32, //number of hidden nodes
- 0.001, //l2 regularization
- new RmsProp(0.0005), null, false
- );
+ public static ActorCriticDenseNetworkConfiguration NET_CONF =
+ ActorCriticDenseNetworkConfiguration.builder()
+ .numLayers(4)
+ .numHiddenNodes(32)
+ .l2(0.001)
+ .updater(new RmsProp(0.0005))
+ .useLSTM(false)
+ .build();
- public static ActorCriticFactoryCompGraphStdDense.Configuration NET_CONF_CG =
- new ActorCriticFactoryCompGraphStdDense.Configuration(
- 2, //number of layers
- 128, //number of hidden nodes
- 0.00001, //l2 regularization
- new RmsProp(0.005), null, true
- );
+ public static ActorCriticDenseNetworkConfiguration NET_CONF_CG =
+ ActorCriticDenseNetworkConfiguration.builder()
+ .numLayers(2)
+ .numHiddenNodes(128)
+ .l2(0.00001)
+ .updater(new RmsProp(0.005))
+ .useLSTM(true)
+ .build();
@Test
public void testModelLoadSave() throws IOException {
- ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[] {7}, 5);
+ ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(NET_CONF).buildActorCritic(new int[]{7}, 5);
File fileValue = File.createTempFile("rl4j-value-", ".model");
File filePolicy = File.createTempFile("rl4j-policy-", ".model");
@@ -63,7 +66,7 @@ public class ActorCriticTest {
assertEquals(acs.valueNet, acs2.valueNet);
assertEquals(acs.policyNet, acs2.policyNet);
- ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[] {37}, 43);
+ ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(NET_CONF_CG).buildActorCritic(new int[]{37}, 43);
File file = File.createTempFile("rl4j-cg-", ".model");
accg.save(file.getAbsolutePath());
@@ -83,15 +86,15 @@ public class ActorCriticTest {
for (double i = eps; i < n; i++) {
for (double j = eps; j < n; j++) {
- INDArray labels = Nd4j.create(new double[] {i / n, 1 - i / n}, new long[]{1,2});
- INDArray output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
+ INDArray labels = Nd4j.create(new double[]{i / n, 1 - i / n}, new long[]{1, 2});
+ INDArray output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
INDArray gradient = loss.computeGradient(labels, output, activation, null);
- output = Nd4j.create(new double[] {j / n, 1 - j / n}, new long[]{1,2});
+ output = Nd4j.create(new double[]{j / n, 1 - j / n}, new long[]{1, 2});
double score = loss.computeScore(labels, output, activation, null, false);
- INDArray output1 = Nd4j.create(new double[] {j / n + eps, 1 - j / n}, new long[]{1,2});
+ INDArray output1 = Nd4j.create(new double[]{j / n + eps, 1 - j / n}, new long[]{1, 2});
double score1 = loss.computeScore(labels, output1, activation, null, false);
- INDArray output2 = Nd4j.create(new double[] {j / n, 1 - j / n + eps}, new long[]{1,2});
+ INDArray output2 = Nd4j.create(new double[]{j / n, 1 - j / n + eps}, new long[]{1, 2});
double score2 = loss.computeScore(labels, output2, activation, null, false);
double gradient1 = (score1 - score) / eps;
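A side effect worth noting from the hunks above: both factory flavours now consume the same ActorCriticDenseNetworkConfiguration type, where they previously had separate nested Configuration classes. A minimal sketch of reusing one configuration object with both factories, built only from calls that appear in this patch (the concrete numbers are illustrative):

    ActorCriticDenseNetworkConfiguration conf = ActorCriticDenseNetworkConfiguration.builder()
            .numLayers(3)
            .numHiddenNodes(64)
            .l2(0.001)
            .updater(new RmsProp(0.0005))
            .useLSTM(false)
            .build();

    // The same configuration type drives both the separate-network and comp-graph factories.
    ActorCriticSeparate acs = new ActorCriticFactorySeparateStdDense(conf).buildActorCritic(new int[]{7}, 5);
    ActorCriticCompGraph accg = new ActorCriticFactoryCompGraphStdDense(conf).buildActorCritic(new int[]{7}, 5);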
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java
index 3f68b8f3c..a9997ec0c 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn;
+import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguration;
import org.junit.Test;
import org.nd4j.linalg.learning.config.RmsProp;
@@ -25,22 +27,20 @@ import java.io.IOException;
import static org.junit.Assert.assertEquals;
/**
- *
* @author saudet
*/
public class DQNTest {
- public static DQNFactoryStdDense.Configuration NET_CONF =
- new DQNFactoryStdDense.Configuration(
- 3, //number of layers
- 16, //number of hidden nodes
- 0.001, //l2 regularization
- new RmsProp(0.0005), null
- );
+ private static DQNDenseNetworkConfiguration NET_CONF =
+ DQNDenseNetworkConfiguration.builder().numLayers(3)
+ .numHiddenNodes(16)
+ .l2(0.001)
+ .updater(new RmsProp(0.0005))
+ .build();
@Test
public void testModelLoadSave() throws IOException {
- DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[] {42}, 13);
+ DQN dqn = new DQNFactoryStdDense(NET_CONF).buildDQN(new int[]{42}, 13);
File file = File.createTempFile("rl4j-dqn-", ".model");
dqn.save(file.getAbsolutePath());
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java
index fe79bdfc7..3f5e761a6 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java
@@ -128,7 +128,7 @@ public class TransformProcessTest {
// Assert
assertFalse(result.isSkipped());
- assertEquals(1, result.getData().shape().length);
+ assertEquals(2, result.getData().shape().length);
assertEquals(1, result.getData().shape()[0]);
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
}
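The expected rank in this assertion goes from 1 to 2, which suggests the transform output now keeps an explicit leading dimension in front of the single observed value; that reading is inferred from the expected shape, not stated anywhere in the patch. For illustration, the shape the updated test expects can be produced like this:

    // A [1, 1] array: rank 2, first dimension 1, single value retrievable by linear index.
    INDArray data = Nd4j.create(new double[]{-10.0}, new long[]{1, 1});
    // data.shape().length == 2, data.shape()[0] == 1, data.getDouble(0) == -10.0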
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index 0707e16ab..0dc16df09 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -24,16 +25,18 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
-import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
-import org.deeplearning4j.rl4j.support.*;
+import org.deeplearning4j.rl4j.support.MockDQN;
+import org.deeplearning4j.rl4j.support.MockEncodable;
+import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
+import org.deeplearning4j.rl4j.support.MockMDP;
+import org.deeplearning4j.rl4j.support.MockNeuralNet;
+import org.deeplearning4j.rl4j.support.MockObservationSpace;
+import org.deeplearning4j.rl4j.support.MockRandom;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -43,8 +46,6 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.IOException;
import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
@@ -186,8 +187,22 @@ public class PolicyTest {
new int[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 });
MockMDP mdp = new MockMDP(observationSpace, 30, random);
- QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
- 0, 1.0, 0, 0, 0, 0, true);
+ QLearningConfiguration conf = QLearningConfiguration.builder()
+ .seed(0L)
+ .maxEpochStep(0)
+ .maxStep(0)
+ .expRepMaxSize(5)
+ .batchSize(1)
+ .targetDqnUpdateFreq(0)
+ .updateStart(0)
+ .rewardFactor(1.0)
+ .gamma(0)
+ .errorClamp(0)
+ .minEpsilon(0)
+ .epsilonNbStep(0)
+ .doubleDQN(true)
+ .build();
+
MockNeuralNet nnMock = new MockNeuralNet();
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
index 56581cc0d..08689b032 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockAsyncConfiguration.java
@@ -1,22 +1,37 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.support;
import lombok.AllArgsConstructor;
-import lombok.Getter;
import lombok.Value;
-import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
+import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
-@AllArgsConstructor
@Value
-public class MockAsyncConfiguration implements AsyncConfiguration {
+@AllArgsConstructor
+public class MockAsyncConfiguration implements IAsyncLearningConfiguration {
- private Integer seed;
+ private Long seed;
private int maxEpochStep;
private int maxStep;
- private int numThread;
- private int nstep;
- private int targetDqnUpdateFreq;
private int updateStart;
private double rewardFactor;
private double gamma;
private double errorClamp;
+ private int numThreads;
+ private int nStep;
+ private int learnerUpdateFrequency;
}
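Because the mock is annotated with @Value and @AllArgsConstructor, its only constructor follows field declaration order, which this change rearranges (and the seed becomes a Long). Call sites therefore have to be updated to the new order; a sketch with illustrative values:

    MockAsyncConfiguration config = new MockAsyncConfiguration(
            123L,   // seed (now Long, previously Integer)
            200,    // maxEpochStep
            8000,   // maxStep
            10,     // updateStart
            0.01,   // rewardFactor
            0.99,   // gamma
            1.0,    // errorClamp
            8,      // numThreads
            5,      // nStep
            500);   // learnerUpdateFrequency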
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
index a3a5598d4..3a2d5230a 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java
@@ -1,3 +1,20 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
package org.deeplearning4j.rl4j.util;
import lombok.Getter;
@@ -5,6 +22,7 @@ import lombok.Setter;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
+import org.deeplearning4j.rl4j.learning.configuration.ILearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.sync.support.MockStatEntry;
import org.deeplearning4j.rl4j.mdp.MDP;
@@ -162,7 +180,7 @@ public class DataManagerTrainingListenerTest {
}
@Override
- public LConfiguration getConfiguration() {
+ public ILearningConfiguration getConfiguration() {
return null;
}
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
index 7400657ef..00b7c4f7a 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
@@ -1,5 +1,6 @@
/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,19 +17,18 @@
package org.deeplearning4j.malmo;
-import java.util.HashMap;
-
+import com.microsoft.msr.malmo.TimestampedStringVector;
+import com.microsoft.msr.malmo.WorldState;
import org.json.JSONArray;
import org.json.JSONObject;
-
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import com.microsoft.msr.malmo.TimestampedStringVector;
-import com.microsoft.msr.malmo.WorldState;
+import java.util.HashMap;
/**
* Observation space that contains a grid of Minecraft blocks
+ *
* @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
*/
public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@@ -43,11 +43,11 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
/**
* Construct observation space from an array of blocks the policy should distinguish between.
- *
- * @param name Name given to Grid element in mission specification
- * @param xSize total x size of grid
- * @param ySize total y size of grid
- * @param zSize total z size of grid
+ *
+ * @param name Name given to Grid element in mission specification
+ * @param xSize total x size of grid
+ * @param ySize total y size of grid
+ * @param zSize total z size of grid
* @param blocks Array of block names to distinguish between. Supports combination of individual strings and/or arrays of strings to map multiple block types to a single observation value. If not specified, it will dynamically map block names to integers - however, because these will be mapped as they are seen, different missions may have different mappings!
*/
public MalmoObservationSpaceGrid(String name, int xSize, int ySize, int zSize, Object... blocks) {
@@ -78,7 +78,7 @@ public class MalmoObservationSpaceGrid extends MalmoObservationSpace {
@Override
public int[] getShape() {
- return new int[] {totalSize};
+ return new int[]{totalSize};
}
@Override