RL4J: Use NeuralNet instances directly in DoubleDQN and StandardDQN (#499)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
Branch: master
Author: Alexandre Boulanger, 2020-06-25 22:23:47 -04:00 (committed by GitHub)
Parent: 654afc810d
Commit: fb578fdecd
11 changed files with 126 additions and 159 deletions
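The change threads the Q-network and the target Q-network straight into the TD-target algorithms instead of handing them a TargetQNetworkSource to pull the networks from. A minimal sketch of the new wiring, assuming qNetwork and targetQNetwork are IDQN instances built elsewhere (the variable names are illustrative, not taken from the diff):

    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;

    // IDQN now extends IOutputNeuralNet, so both networks can be passed to the algorithms as-is.
    ITDTargetAlgorithm<Integer> tdTargetAlgorithm = isDoubleDQN
            ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
            : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);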

File: DQNNeuralNetUpdateRule.java (package org.deeplearning4j.rl4j.agent.update)

@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
@ -28,13 +27,10 @@ import java.util.List;
// Temporary class that will be replaced with a more generic class that delegates gradient computation
// and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
@Getter
private final IDQN qNetwork;
@Getter
private IDQN targetQNetwork;
private final IDQN targetQNetwork;
private final int targetUpdateFrequency;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(this, gamma, errorClamp)
: new StandardDQN(this, gamma, errorClamp);
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
}
@Override
@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels());
if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork = qNetwork.clone();
targetQNetwork.copy(qNetwork);
}
}
}
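A note on the update path above: targetQNetwork is now final and is the very instance the TD-target algorithm holds, so the periodic refresh copies the online network's parameters into it with targetQNetwork.copy(qNetwork) instead of replacing the field with a fresh clone, which would have left the algorithm pointing at a stale network. A hedged sketch of the resulting cycle, with the method shape assumed from the hunk:

    // Runs once per training batch (a sketch, not the file's verbatim code).
    public void update(List<Transition<Integer>> trainingBatch) {
        DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); // build TD targets
        qNetwork.fit(targets.getFeatures(), targets.getLabels());            // train the online network
        if (++updateCount % targetUpdateFrequency == 0) {
            targetQNetwork.copy(qNetwork); // sync the target weights in place; shared references stay valid
        }
    }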

Deleted file: TargetQNetworkSource.java (package org.deeplearning4j.rl4j.learning.sync.qlearning)

@ -1,28 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
/**
* An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
*
* @author Alexandre Boulanger
*/
public interface TargetQNetworkSource extends QNetworkSource {
IDQN getTargetQNetwork();
}

File: BaseDQNAlgorithm.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -16,8 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
private final TargetQNetworkSource qTargetNetworkSource;
private final IOutputNeuralNet targetQNetwork;
/**
* In the literature, this corresponds to Q{net}(s(t+1), a)
@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
*/
protected INDArray targetQNetworkNextObservation;
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
this.qTargetNetworkSource = qTargetNetworkSource;
protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qNetwork, gamma);
this.targetQNetwork = targetQNetwork;
}
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
this.qTargetNetworkSource = qTargetNetworkSource;
protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qNetwork, gamma, errorClamp);
this.targetQNetwork = targetQNetwork;
}
@Override
protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations);
qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
qNetworkNextObservation = qNetwork.output(nextObservations);
targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
}
}

File: BaseTDTargetAlgorithm.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -17,7 +17,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
@ -30,7 +30,7 @@ import java.util.List;
*/
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
protected final QNetworkSource qNetworkSource;
protected final IOutputNeuralNet qNetwork;
protected final double gamma;
private final double errorClamp;
@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
/**
*
* @param qNetworkSource The source of the Q-Network
* @param qNetwork The Q-Network
* @param gamma The discount factor
* @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
*/
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
this.qNetworkSource = qNetworkSource;
protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
this.qNetwork = qNetwork;
this.gamma = gamma;
this.errorClamp = errorClamp;
@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
/**
*
* @param qNetworkSource The source of the Q-Network
* @param qNetwork The Q-Network
* @param gamma The discount factor
* Note: Error clamping is disabled with this ctor
*/
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
this(qNetworkSource, gamma, Double.NaN);
protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
this(qNetwork, gamma, Double.NaN);
}
/**
@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
initComputation(observations, nextObservations);
INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
INDArray updatedQValues = qNetwork.output(observations);
for (int i = 0; i < size; ++i) {
Transition<Integer> transition = transitions.get(i);
double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
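For context on errorClamp, documented in the constructors above: it bounds how far each new Q-value may move from the network's current estimate, and Double.NaN disables the bound. A hedged sketch of that clamping inside the per-transition loop, assuming Transition exposes the chosen action via getAction() (illustrative, not this file's verbatim code):

    double previousQValue = updatedQValues.getDouble(i, transition.getAction());
    if (!Double.isNaN(errorClamp)) {
        double lowBound = previousQValue - errorClamp;
        double highBound = previousQValue + errorClamp;
        yTarget = Math.min(highBound, Math.max(yTarget, lowBound));
    }
    updatedQValues.putScalar(i, transition.getAction(), yTarget);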

File: DoubleDQN.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -16,7 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
// In the literature, this corresponds to: max_{a}Q(s_{t+1}, a)
private INDArray maxActionsFromQNetworkNextObservation;
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qNetwork, targetQNetwork, gamma);
}
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qNetwork, targetQNetwork, gamma, errorClamp);
}
@Override

File: StandardDQN.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -16,7 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
// In the literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
private INDArray maxActionsFromQTargetNextObservation;
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
super(qTargetNetworkSource, gamma);
public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qNetwork, targetQNetwork, gamma);
}
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp);
public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qNetwork, targetQNetwork, gamma, errorClamp);
}
@Override
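The two algorithms differ only in how the bootstrap term of a non-terminal transition is chosen. Restating the comments in the hunks above, with s' the next observation (a hedged summary, not either file's verbatim code):

    // StandardDQN: the target network both selects and evaluates the next action.
    //   y = reward                                               if terminal
    //   y = reward + gamma * max_a Q_target(s', a)               otherwise
    //
    // DoubleDQN: the online network selects the action, the target network evaluates it.
    //   y = reward                                               if terminal
    //   y = reward + gamma * Q_target(s', argmax_a Q(s', a))     otherwise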

File: QNetworkSource.java (package org.deeplearning4j.rl4j.learning.sync.qlearning), replaced by IOutputNeuralNet.java (package org.deeplearning4j.rl4j.network)

@ -1,28 +1,38 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
/**
* An interface for all implementations capable of supplying a Q-Network
*
* @author Alexandre Boulanger
*/
public interface QNetworkSource {
IDQN getQNetwork();
}
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* An interface defining the output aspect of a {@link NeuralNet}.
*/
public interface IOutputNeuralNet {
/**
* Compute the output for the supplied observation.
* @param observation An {@link Observation}
* @return The output of the network
*/
INDArray output(Observation observation);
/**
* Compute the output for the supplied batch.
* @param batch The input batch
* @return The output of the network
*/
INDArray output(INDArray batch);
}
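Because the TD-target algorithms now depend only on this interface, anything that can produce network outputs can be plugged in; the tests below mock it with Mockito. A minimal hand-rolled stub in the same spirit, assuming Observation exposes its underlying INDArray via getData() (an assumption; that accessor is not shown in this diff):

    import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
    import org.deeplearning4j.rl4j.observation.Observation;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Hypothetical identity stub: echoes its input, much like the mocks configured in the tests below.
    public class IdentityOutputNeuralNet implements IOutputNeuralNet {
        @Override
        public INDArray output(Observation observation) {
            return output(observation.getData());
        }

        @Override
        public INDArray output(INDArray batch) {
            return batch;
        }
    }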

File: IDQN.java (package org.deeplearning4j.rl4j.network.dqn)

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* This neural net quantifies the value of each action given a state
*
*/
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
boolean isRecurrent();
@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
void fit(INDArray input, INDArray[] labels);
INDArray output(INDArray batch);
INDArray output(Observation observation);
INDArray[] outputAll(INDArray batch);
NN clone();

File: DoubleDQNTest.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -1,10 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@ -13,16 +16,29 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class DoubleDQNTest {
@Mock
IOutputNeuralNet qNetworkMock;
@Mock
IOutputNeuralNet targetQNetworkMock;
@Before
public void setup() {
when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
}
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
@ -31,7 +47,7 @@ public class DoubleDQNTest {
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -46,9 +62,7 @@ public class DoubleDQNTest {
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
@ -57,7 +71,7 @@ public class DoubleDQNTest {
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -72,9 +86,7 @@ public class DoubleDQNTest {
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
@ -87,7 +99,7 @@ public class DoubleDQNTest {
}
};
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);

File: StandardDQNTest.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm)

@ -1,10 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class StandardDQNTest {
@Mock
IOutputNeuralNet qNetworkMock;
@Mock
IOutputNeuralNet targetQNetworkMock;
@Before
public void setup() {
when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
}
@Test
public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -30,7 +47,7 @@ public class StandardDQNTest {
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -45,10 +62,6 @@ public class StandardDQNTest {
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -56,7 +69,7 @@ public class StandardDQNTest {
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
@ -71,10 +84,6 @@ public class StandardDQNTest {
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{
add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -86,7 +95,7 @@ public class StandardDQNTest {
}
};
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);

Deleted file: MockTargetQNetworkSource.java (package org.deeplearning4j.rl4j.learning.sync.support)

@ -1,26 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
public class MockTargetQNetworkSource implements TargetQNetworkSource {
private final IDQN qNetwork;
private final IDQN targetQNetwork;
public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
this.qNetwork = qNetwork;
this.targetQNetwork = targetQNetwork;
}
@Override
public IDQN getTargetQNetwork() {
return targetQNetwork;
}
@Override
public IDQN getQNetwork() {
return qNetwork;
}
}