From 4190c9ee0f6c764ecb3b3a1a658466ed1c342da7 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Sun, 28 Jun 2020 23:20:53 -0400 Subject: [PATCH 1/2] RL4J: Add NeuralNetUpdater (#500) Signed-off-by: Alexandre Boulanger --- .../agent/update/DQNNeuralNetUpdateRule.java | 18 +++--- .../neuralnetupdater/INeuralNetUpdater.java | 30 +++++++++ .../neuralnetupdater/NeuralNetUpdater.java | 62 +++++++++++++++++++ .../rl4j/learning/async/AsyncThread.java | 7 ++- .../rl4j/learning/async/IAsyncGlobal.java | 6 +- .../BaseTDTargetAlgorithm.java | 2 +- .../TDTargetAlgorithm/ITDTargetAlgorithm.java | 2 +- .../rl4j/network/ITrainableNeuralNet.java | 40 ++++++++++++ .../rl4j/network/NeuralNet.java | 14 +---- .../rl4j/network/ac/ActorCriticCompGraph.java | 22 +++++++ .../rl4j/network/ac/ActorCriticSeparate.java | 24 +++++++ .../rl4j/network/ac/IActorCritic.java | 17 ----- .../deeplearning4j/rl4j/network/dqn/DQN.java | 20 ++++-- .../deeplearning4j/rl4j/network/dqn/IDQN.java | 20 +----- .../NeuralNetUpdaterTest.java | 51 +++++++++++++++ .../TDTargetAlgorithm/DoubleDQNTest.java | 6 +- .../TDTargetAlgorithm/StandardDQNTest.java | 6 +- .../rl4j/learning/sync/support/MockDQN.java | 22 ++++--- .../rl4j/policy/PolicyTest.java | 16 +++++ .../deeplearning4j/rl4j/support/MockDQN.java | 22 ++++--- .../rl4j/support/MockNeuralNet.java | 24 ++++++- 21 files changed, 329 insertions(+), 102 deletions(-) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/INeuralNetUpdater.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdater.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdaterTest.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java index 98873b827..c359e02ce 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java @@ -16,6 +16,8 @@ package org.deeplearning4j.rl4j.agent.update; import lombok.Getter; +import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.INeuralNetUpdater; +import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.NeuralNetUpdater; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; @@ -29,30 +31,26 @@ import java.util.List; // and network update to sub components. public class DQNNeuralNetUpdateRule implements IUpdateRule> { - private final IDQN qNetwork; private final IDQN targetQNetwork; - private final int targetUpdateFrequency; + private final INeuralNetUpdater updater; private final ITDTargetAlgorithm tdTargetAlgorithm; @Getter private int updateCount = 0; - public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) { - this.qNetwork = qNetwork; + public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) { this.targetQNetwork = qNetwork.clone(); - this.targetUpdateFrequency = targetUpdateFrequency; tdTargetAlgorithm = isDoubleDQN ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp) : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp); + updater = new NeuralNetUpdater(qNetwork, targetQNetwork, targetUpdateFrequency); + } @Override public void update(List> trainingBatch) { - DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); - qNetwork.fit(targets.getFeatures(), targets.getLabels()); - if(++updateCount % targetUpdateFrequency == 0) { - targetQNetwork.copy(qNetwork); - } + DataSet targets = tdTargetAlgorithm.compute(trainingBatch); + updater.update(targets); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/INeuralNetUpdater.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/INeuralNetUpdater.java new file mode 100644 index 000000000..f17c4f11e --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/INeuralNetUpdater.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update.neuralnetupdater; + +import org.nd4j.linalg.dataset.api.DataSet; + +/** + * The role of INeuralNetUpdater implementations is to update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet} + * from a {@link DataSet}.

+ */ +public interface INeuralNetUpdater { + /** + * Update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}. + * @param featuresLabels A Dataset that will be used to update the network. + */ + void update(DataSet featuresLabels); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdater.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdater.java new file mode 100644 index 000000000..3dc778251 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdater.java @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.agent.update.neuralnetupdater; + +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; +import org.nd4j.linalg.dataset.api.DataSet; + +/** + * A {@link INeuralNetUpdater} that updates a neural network and sync a target network at defined intervals + */ +public class NeuralNetUpdater implements INeuralNetUpdater { + + private final ITrainableNeuralNet current; + private final ITrainableNeuralNet target; + + private int updateCount = 0; + private final int targetUpdateFrequency; + + /** + * @param current The current {@link ITrainableNeuralNet network} + * @param target The target {@link ITrainableNeuralNet network} + * @param targetUpdateFrequency Will synchronize the target network at every targetUpdateFrequency updates + */ + public NeuralNetUpdater(ITrainableNeuralNet current, + ITrainableNeuralNet target, + int targetUpdateFrequency) { + this.current = current; + this.target = target; + + this.targetUpdateFrequency = targetUpdateFrequency; + } + + /** + * Update the current network + * @param featuresLabels A Dataset that will be used to update the network. + */ + @Override + public void update(DataSet featuresLabels) { + current.fit(featuresLabels); + syncTargetNetwork(); + } + + private void syncTargetNetwork() { + if(++updateCount % targetUpdateFrequency == 0) { + target.copy(current); + } + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 26d8d5e02..74123d0d2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -23,16 +23,19 @@ import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.gym.StepReply; -import org.deeplearning4j.rl4j.learning.*; +import org.deeplearning4j.rl4j.learning.HistoryProcessor; +import org.deeplearning4j.rl4j.learning.IEpochTrainer; +import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.factory.Nd4j; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java index b9725499a..8f6bf46bf 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java @@ -18,11 +18,9 @@ package org.deeplearning4j.rl4j.learning.async; import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; -import java.util.concurrent.atomic.AtomicInteger; - -public interface IAsyncGlobal { +public interface IAsyncGlobal { boolean isTrainingComplete(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java index e0ede18d7..460567a86 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java @@ -80,7 +80,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm> transitions) { + public DataSet compute(List> transitions) { int size = transitions.size(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java index 199c0e7e3..cd7588bfe 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/ITDTargetAlgorithm.java @@ -34,5 +34,5 @@ public interface ITDTargetAlgorithm { * @param transitions The transitions from the experience replay * @return A DataSet where every element is the observation and the estimated Q-Values for all actions */ - DataSet computeTDTargets(List> transitions); + DataSet compute(List> transitions); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java new file mode 100644 index 000000000..da91d7e6d --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java @@ -0,0 +1,40 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.network; + +import org.nd4j.linalg.dataset.api.DataSet; + +/** + * An interface defining the trainable aspect of a {@link NeuralNet}. + */ +public interface ITrainableNeuralNet { + /** + * Train the neural net using the supplied feature-labels + * @param featuresLabels The feature-labels + */ + void fit(DataSet featuresLabels); + + /** + * Changes this instance to be a copy of the from network. + * @param from The network that will be the source of the copy. + */ + void copy(NET_TYPE from); + + /** + * Creates a clone of the network instance. + */ + NET_TYPE clone(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java index 38c91562a..7823ea906 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java @@ -29,7 +29,7 @@ import java.io.OutputStream; * Factorisation between ActorCritic and DQN neural net. * Useful for AsyncLearning and Thread code. */ -public interface NeuralNet { +public interface NeuralNet extends IOutputNeuralNet, ITrainableNeuralNet { /** * Returns the underlying MultiLayerNetwork or ComputationGraph objects. @@ -52,18 +52,6 @@ public interface NeuralNet { */ INDArray[] outputAll(INDArray batch); - /** - * clone the Neural Net with the same paramaeters - * @return the cloned neural net - */ - NN clone(); - - /** - * copy the parameters from a neural net - * @param from where to copy parameters - */ - void copy(NN from); - /** * Calculate the gradients from input and label (target) of all outputs * @param input input batch diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java index 63438bb74..6e37fb0f3 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java @@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.network.ac; import lombok.Getter; +import org.apache.commons.lang3.NotImplementedException; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.gradient.Gradient; @@ -25,8 +26,10 @@ import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import java.io.IOException; import java.io.OutputStream; @@ -80,6 +83,11 @@ public class ActorCriticCompGraph implements IActorCritic return nn; } + @Override + public void fit(DataSet featuresLabels) { + fit(featuresLabels.getFeatures(), new INDArray[] { featuresLabels.getLabels() }); + } + public void copy(ActorCriticCompGraph from) { cg.setParams(from.cg.params()); } @@ -137,5 +145,19 @@ public class ActorCriticCompGraph implements IActorCritic public void save(String pathValue, String pathPolicy) throws IOException { throw new UnsupportedOperationException("Call save(path)"); } + + @Override + public INDArray output(Observation observation) { + // TODO: signature of output() will change to return a class that has named outputs to support network like + // this one (output from the value-network and another output for the policy-network + throw new NotImplementedException("Not implemented: will be done with AgentLearner async support"); + } + + @Override + public INDArray output(INDArray batch) { + // TODO: signature of output() will change to return a class that has named outputs to support network like + // this one (output from the value-network and another output for the policy-network + throw new NotImplementedException("Not implemented: will be done with AgentLearner async support"); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java index 474a99428..a086b6065 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.network.ac; import lombok.Getter; +import org.apache.commons.lang3.NotImplementedException; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; @@ -24,8 +25,10 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import java.io.IOException; import java.io.OutputStream; @@ -86,6 +89,13 @@ public class ActorCriticSeparate implements IAct return nn; } + @Override + public void fit(DataSet featuresLabels) { + // TODO: signature of fit() will change from DataSet to a class that has named labels to support network like + // this one (labels for the value-network and another labels for the policy-network + throw new NotImplementedException("Not implemented: will be done with AgentLearner async support"); + } + public void copy(NN from) { valueNet.setParams(from.valueNet.params()); policyNet.setParams(from.policyNet.params()); @@ -164,6 +174,20 @@ public class ActorCriticSeparate implements IAct ModelSerializer.writeModel(valueNet, pathValue, true); ModelSerializer.writeModel(policyNet, pathPolicy, true); } + + @Override + public INDArray output(Observation observation) { + // TODO: signature of output() will change to return a class that has named outputs to support network like + // this one (output from the value-network and another output for the policy-network + throw new NotImplementedException("Not implemented: will be done with AgentLearner async support"); + } + + @Override + public INDArray output(INDArray batch) { + // TODO: signature of output() will change to return a class that has named outputs to support network like + // this one (output from the value-network and another output for the policy-network + throw new NotImplementedException("Not implemented: will be done with AgentLearner async support"); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java index 47b309c2f..fd49b92b1 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java @@ -16,7 +16,6 @@ package org.deeplearning4j.rl4j.network.ac; -import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.rl4j.network.NeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,27 +32,11 @@ import java.io.OutputStream; */ public interface IActorCritic extends NeuralNet { - boolean isRecurrent(); - - void reset(); - - void fit(INDArray input, INDArray[] labels); - //FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) ! INDArray[] outputAll(INDArray batch); - NN clone(); - - void copy(NN from); - - Gradient[] gradient(INDArray input, INDArray[] labels); - - void applyGradient(Gradient[] gradient, int batchSize); - void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException; void save(String pathValue, String pathPolicy) throws IOException; - double getLatestScore(); - } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java index b3293c1b6..260ea6aa2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java @@ -25,6 +25,7 @@ import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import java.io.IOException; import java.io.OutputStream; @@ -33,7 +34,7 @@ import java.util.Collection; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16. */ -public class DQN implements IDQN { +public class DQN implements IDQN { final protected MultiLayerNetwork mln; @@ -79,16 +80,23 @@ public class DQN implements IDQN { return new INDArray[] {output(batch)}; } - public NN clone() { - NN nn = (NN)new DQN(mln.clone()); - nn.mln.setListeners(mln.getListeners()); - return nn; + @Override + public void fit(DataSet featuresLabels) { + fit(featuresLabels.getFeatures(), featuresLabels.getLabels()); } - public void copy(NN from) { + @Override + public void copy(DQN from) { mln.setParams(from.mln.params()); } + @Override + public DQN clone() { + DQN nn = new DQN(mln.clone()); + nn.mln.setListeners(mln.getListeners()); + return nn; + } + public Gradient[] gradient(INDArray input, INDArray labels) { mln.setInput(input); mln.setLabels(labels); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index daed646c5..1cb9a18d7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -17,9 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; -import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -28,27 +26,11 @@ import org.nd4j.linalg.api.ndarray.INDArray; * This neural net quantify the value of each action given a state * */ -public interface IDQN extends NeuralNet, IOutputNeuralNet { - - boolean isRecurrent(); - - void reset(); +public interface IDQN extends NeuralNet { void fit(INDArray input, INDArray labels); - void fit(INDArray input, INDArray[] labels); - - INDArray[] outputAll(INDArray batch); - - NN clone(); - - void copy(NN from); - Gradient[] gradient(INDArray input, INDArray label); Gradient[] gradient(INDArray input, INDArray[] label); - - void applyGradient(Gradient[] gradient, int batchSize); - - double getLatestScore(); } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdaterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdaterTest.java new file mode 100644 index 000000000..578da2d50 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/update/neuralnetupdater/NeuralNetUpdaterTest.java @@ -0,0 +1,51 @@ +package org.deeplearning4j.rl4j.agent.update.neuralnetupdater; + +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.nd4j.linalg.dataset.api.DataSet; + +import static org.mockito.Mockito.*; + +@RunWith(MockitoJUnitRunner.class) +public class NeuralNetUpdaterTest { + + @Mock + ITrainableNeuralNet currentMock; + + @Mock + ITrainableNeuralNet targetMock; + + @Test + public void when_callingUpdate_expect_currentUpdatedAndtargetNotChanged() { + // Arrange + NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, Integer.MAX_VALUE); + DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet(); + + // Act + sut.update(featureLabels); + + // Assert + verify(currentMock, times(1)).fit(featureLabels); + verify(targetMock, never()).fit(any()); + } + + @Test + public void when_callingUpdate_expect_targetUpdatedFromCurrentAtFrequency() { + // Arrange + NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, 3); + DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet(); + + // Act + sut.update(featureLabels); + sut.update(featureLabels); + sut.update(featureLabels); + + // Assert + verify(currentMock, never()).copy(any()); + verify(targetMock, times(1)).copy(currentMock); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 0f03a5370..d5a946a65 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -50,7 +50,7 @@ public class DoubleDQNTest { DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); @@ -74,7 +74,7 @@ public class DoubleDQNTest { DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); @@ -102,7 +102,7 @@ public class DoubleDQNTest { DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index 6aead9e76..bc7812b36 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -50,7 +50,7 @@ public class StandardDQNTest { StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); @@ -72,7 +72,7 @@ public class StandardDQNTest { StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); @@ -98,7 +98,7 @@ public class StandardDQNTest { StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act - DataSet result = sut.computeTDTargets(transitions); + DataSet result = sut.compute(transitions); // Assert INDArray evaluatedQValues = result.getLabels(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java index e5a87cb93..7fb3f8300 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java @@ -2,10 +2,12 @@ package org.deeplearning4j.rl4j.learning.sync.support; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import java.io.IOException; import java.io.OutputStream; @@ -66,21 +68,21 @@ public class MockDQN implements IDQN { return new INDArray[0]; } + @Override + public void fit(DataSet featuresLabels) { + throw new UnsupportedOperationException(); + } + + @Override + public void copy(ITrainableNeuralNet from) { + throw new UnsupportedOperationException(); + } + @Override public IDQN clone() { return null; } - @Override - public void copy(NeuralNet from) { - - } - - @Override - public void copy(IDQN from) { - - } - @Override public Gradient[] gradient(INDArray input, INDArray label) { return new Gradient[0]; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 249304afb..b05e55a19 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -36,6 +36,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -88,6 +89,11 @@ public class PolicyTest { throw new UnsupportedOperationException(); } + @Override + public void fit(DataSet featuresLabels) { + throw new UnsupportedOperationException(); + } + @Override public void copy(NN from) { throw new UnsupportedOperationException(); @@ -127,6 +133,16 @@ public class PolicyTest { public void save(String filename) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public INDArray output(Observation observation) { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray output(INDArray batch) { + throw new UnsupportedOperationException(); + } } @Test diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java index 1462d2779..1b9c63b65 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java @@ -2,11 +2,13 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.dataset.api.DataSet; import java.io.IOException; import java.io.OutputStream; @@ -63,6 +65,16 @@ public class MockDQN implements IDQN { return new INDArray[] { batch.mul(-1.0) }; } + @Override + public void fit(DataSet featuresLabels) { + throw new UnsupportedOperationException(); + } + + @Override + public void copy(ITrainableNeuralNet from) { + throw new UnsupportedOperationException(); + } + @Override public IDQN clone() { MockDQN clone = new MockDQN(); @@ -71,16 +83,6 @@ public class MockDQN implements IDQN { return clone; } - @Override - public void copy(NeuralNet from) { - - } - - @Override - public void copy(IDQN from) { - - } - @Override public Gradient[] gradient(INDArray input, INDArray label) { gradientParams.add(new Pair(input, label)); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java index a5d7c5f3e..fe5b3a7f9 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java @@ -2,8 +2,11 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.ITrainableNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; @@ -39,15 +42,20 @@ public class MockNeuralNet implements NeuralNet { } @Override - public NeuralNet clone() { - return this; + public void fit(DataSet featuresLabels) { + throw new UnsupportedOperationException(); } @Override - public void copy(NeuralNet from) { + public void copy(ITrainableNeuralNet from) { ++copyCallCount; } + @Override + public NeuralNet clone() { + return this; + } + @Override public Gradient[] gradient(INDArray input, INDArray[] labels) { return new Gradient[0]; @@ -77,4 +85,14 @@ public class MockNeuralNet implements NeuralNet { public void save(String filename) throws IOException { } + + @Override + public INDArray output(Observation observation) { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray output(INDArray batch) { + throw new UnsupportedOperationException(); + } } \ No newline at end of file From 99b85c5006d9e70e177fe1e0945885dd8c8bdf5b Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Mon, 29 Jun 2020 10:48:25 +0400 Subject: [PATCH 2/2] Python4j: bytes conversion fix + test (#497) * bytes fix+test * bytes fix+test --- .../java/org/nd4j/python4j/PythonTypes.java | 2 +- .../test/java/PythonPrimitiveTypesTest.java | 23 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java index 089c8aefe..d23c70dde 100644 --- a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java @@ -258,7 +258,7 @@ public class PythonTypes { return ret; }else if (javaObject instanceof byte[]){ byte[] arr = (byte[]) javaObject; - for (int x : arr) ret.add(x); + for (int x : arr) ret.add(x & 0xff); return ret; } else if (javaObject instanceof long[]) { long[] arr = (long[]) javaObject; diff --git a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java index 94423f7de..5080b8b35 100644 --- a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java +++ b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java @@ -81,16 +81,31 @@ public class PythonPrimitiveTypesTest { } @Test public void testBytes() { + byte[] bytes = new byte[256]; + for (int i = 0; i < 256; i++) { + bytes[i] = (byte) i; + } + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes)); + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES)); + String code = "b2=b1"; + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertArrayEquals(bytes, (byte[]) outputs.get(0).getValue()); + } + + @Test + public void testBytes2() { byte[] bytes = new byte[]{97, 98, 99}; List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes)); + inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes)); List outputs = new ArrayList<>(); outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); - outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES)); - String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'"; + outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES)); + String code = "s1 = ''.join(chr(c) for c in b1)\nb2=b'def'"; PythonExecutioner.exec(code, inputs, outputs); Assert.assertEquals("abc", outputs.get(0).getValue()); - Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[])outputs.get(1).getValue()); + Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue()); } }