RL4J: Use NeuralNet instances directly in DoubleDQN and StandardDQN (#499)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
Alexandre Boulanger 2020-06-25 22:23:47 -04:00 committed by GitHub
parent 654afc810d
commit fb578fdecd
11 changed files with 126 additions and 159 deletions
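
In short: this change removes the QNetworkSource/TargetQNetworkSource indirection. StandardDQN and DoubleDQN now receive the Q-network and the target Q-network directly, typed against the new IOutputNeuralNet interface (which exposes only the output() methods of a NeuralNet), and the tests swap the hand-rolled MockDQN/MockTargetQNetworkSource helpers for Mockito mocks.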

File: org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java

@@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
 import lombok.Getter;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
@@ -28,13 +27,10 @@ import java.util.List;
 // Temporary class that will be replaced with a more generic class that delegates gradient computation
 // and network update to sub components.
-public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
+public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {

     @Getter
     private final IDQN qNetwork;
-
-    @Getter
-    private IDQN targetQNetwork;
+    private final IDQN targetQNetwork;

     private final int targetUpdateFrequency;
     private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>
         this.targetQNetwork = qNetwork.clone();
         this.targetUpdateFrequency = targetUpdateFrequency;
         tdTargetAlgorithm = isDoubleDQN
-                ? new DoubleDQN(this, gamma, errorClamp)
-                : new StandardDQN(this, gamma, errorClamp);
+                ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
+                : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
@@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>
         DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
         qNetwork.fit(targets.getFeatures(), targets.getLabels());
         if(++updateCount % targetUpdateFrequency == 0) {
-            targetQNetwork = qNetwork.clone();
+            targetQNetwork.copy(qNetwork);
         }
     }
 }
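
A behavioral subtlety worth noting: DoubleDQN and StandardDQN now capture direct references to the two networks at construction time, so the target network must be refreshed in place with copy() rather than by reassigning the field to a fresh clone(), which would leave the algorithms reading a stale network. A minimal sketch under those assumptions (0.99 is an illustrative gamma):

    // Sketch: wiring a TD-target algorithm directly to the two networks.
    static ITDTargetAlgorithm<Integer> wire(IDQN qNetwork) {
        IDQN targetQNetwork = qNetwork.clone();
        // Both references are captured by the algorithm here...
        ITDTargetAlgorithm<Integer> algorithm = new StandardDQN(qNetwork, targetQNetwork, 0.99);
        // ...so later syncs must mutate the same instance:
        targetQNetwork.copy(qNetwork);        // seen by `algorithm`
        // targetQNetwork = qNetwork.clone(); // would NOT be seen by `algorithm`
        return algorithm;
    }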

File (deleted): org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java

@@ -1,28 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-/**
- * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface TargetQNetworkSource extends QNetworkSource {
-    IDQN getTargetQNetwork();
-}

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java

@@ -16,8 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;

 /**
@@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  */
 public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

-    private final TargetQNetworkSource qTargetNetworkSource;
+    private final IOutputNeuralNet targetQNetwork;

     /**
      * In litterature, this corresponds to Q{net}(s(t+1), a)
@@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
      */
     protected INDArray targetQNetworkNextObservation;

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, gamma);
+        this.targetQNetwork = targetQNetwork;
     }

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, gamma, errorClamp);
+        this.targetQNetwork = targetQNetwork;
     }

     @Override
     protected void initComputation(INDArray observations, INDArray nextObservations) {
         super.initComputation(observations, nextObservations);

-        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
-
-        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
+        qNetworkNextObservation = qNetwork.output(nextObservations);
         targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
     }
 }

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java

@@ -17,7 +17,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
@@ -30,7 +30,7 @@ import java.util.List;
  */
 public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

-    protected final QNetworkSource qNetworkSource;
+    protected final IOutputNeuralNet qNetwork;

     protected final double gamma;
     private final double errorClamp;
@@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>
     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
-        this.qNetworkSource = qNetworkSource;
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
+        this.qNetwork = qNetwork;
         this.gamma = gamma;
         this.errorClamp = errorClamp;
@@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>
     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * Note: Error clamping is disabled with this ctor
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
-        this(qNetworkSource, gamma, Double.NaN);
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
+        this(qNetwork, gamma, Double.NaN);
     }

     /**
@@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>
         initComputation(observations, nextObservations);

-        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
+        INDArray updatedQValues = qNetwork.output(observations);

         for (int i = 0; i < size; ++i) {
             Transition<Integer> transition = transitions.get(i);
             double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
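
For orientation, the TD target that computeTDTargets assembles per transition i is the usual one (a restatement of the code's semantics, not text from the commit); \hat{q}_i is the subclass-specific estimate supplied via computeTarget, and when errorClamp is finite the result is kept within errorClamp of the current Q-value:

    y_i =
    \begin{cases}
      r_i & \text{if transition } i \text{ is terminal} \\
      r_i + \gamma\,\hat{q}_i & \text{otherwise}
    \end{cases}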

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java

@@ -16,7 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a)
     private INDArray maxActionsFromQNetworkNextObservation;

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
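
Double DQN decouples action selection from action evaluation: the online Q-network chooses the greedy next action and the target network scores it. In standard notation (my summary, not from the diff):

    y_i = r_i + \gamma\, Q_{\text{target}}\!\left(s_{t+1}, \arg\max_{a} Q(s_{t+1}, a)\right)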

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java

@@ -16,7 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
     private INDArray maxActionsFromQTargetNextObservation;

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
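
Standard DQN, by contrast, lets the target network both choose and score the next action (my summary):

    y_i = r_i + \gamma\, \max_{a} Q_{\text{target}}(s_{t+1}, a)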

File (renamed): org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java → org/deeplearning4j/rl4j/network/IOutputNeuralNet.java

@@ -1,28 +1,38 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.rl4j.learning.sync.qlearning;
+package org.deeplearning4j.rl4j.network;

-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.nd4j.linalg.api.ndarray.INDArray;

 /**
- * An interface for all implementations capable of supplying a Q-Network
- *
- * @author Alexandre Boulanger
+ * An interface defining the output aspect of a {@link NeuralNet}.
  */
-public interface QNetworkSource {
-    IDQN getQNetwork();
+public interface IOutputNeuralNet {
+    /**
+     * Compute the output for the supplied observation.
+     * @param observation An {@link Observation}
+     * @return The output of the network
+     */
+    INDArray output(Observation observation);
+
+    /**
+     * Compute the output for the supplied batch.
+     * @param batch
+     * @return The output of the network
+     */
+    INDArray output(INDArray batch);
 }
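
Any object exposing these two output() methods can now drive the TD-target algorithms. As a minimal sketch (a hypothetical helper, not part of this commit; it assumes Observation exposes its underlying INDArray via getData()), an identity fake mirrors the Mockito stubs used in the updated tests:

    package org.deeplearning4j.rl4j.support;

    import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
    import org.deeplearning4j.rl4j.observation.Observation;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Echoes its input, so "Q-values" equal the observation features.
    public class IdentityOutputNeuralNet implements IOutputNeuralNet {
        @Override
        public INDArray output(Observation observation) {
            return observation.getData(); // assumes Observation.getData() exists
        }

        @Override
        public INDArray output(INDArray batch) {
            return batch;
        }
    }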

File: org/deeplearning4j/rl4j/network/dqn/IDQN.java

@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.network.dqn;

 import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * This neural net quantify the value of each action given a state
  *
  */
-public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
+public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {

     boolean isRecurrent();
@@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
     void fit(INDArray input, INDArray[] labels);

-    INDArray output(INDArray batch);
-
-    INDArray output(Observation observation);
-
     INDArray[] outputAll(INDArray batch);

     NN clone();

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java

@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -13,16 +16,29 @@ import java.util.ArrayList;
 import java.util.List;

 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class DoubleDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -31,7 +47,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -46,9 +62,7 @@ public class DoubleDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -57,7 +71,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -72,9 +86,7 @@ public class DoubleDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -87,7 +99,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
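
Note the stubbing pattern: qNetworkMock always echoes its input, so the "current" Q-values equal the observation features, while targetQNetworkMock echoes in the terminal test and negates in the others; the expected TD targets asserted below can therefore be computed by hand.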

File: org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java

@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.List;

-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class StandardDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -30,7 +47,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -45,10 +62,6 @@ public class StandardDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -56,7 +69,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -71,10 +84,6 @@ public class StandardDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -86,7 +95,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

File (deleted): org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java

@@ -1,26 +0,0 @@
-package org.deeplearning4j.rl4j.learning.sync.support;
-
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-public class MockTargetQNetworkSource implements TargetQNetworkSource {
-
-    private final IDQN qNetwork;
-    private final IDQN targetQNetwork;
-
-    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
-        this.qNetwork = qNetwork;
-        this.targetQNetwork = targetQNetwork;
-    }
-
-    @Override
-    public IDQN getTargetQNetwork() {
-        return targetQNetwork;
-    }
-
-    @Override
-    public IDQN getQNetwork() {
-        return qNetwork;
-    }
-}