RL4J: Use NeuralNet instances directly in DoubleDQN and StandardDQN (#499)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

Branch: master
Parent: 654afc810d
Commit: fb578fdecd

DQNNeuralNetUpdateRule.java:
@@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
 import lombok.Getter;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
@@ -28,13 +27,10 @@ import java.util.List;
 // Temporary class that will be replaced with a more generic class that delegates gradient computation
 // and network update to sub components.
-public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
+public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {

     @Getter
     private final IDQN qNetwork;

     @Getter
-    private IDQN targetQNetwork;
+    private final IDQN targetQNetwork;

     private final int targetUpdateFrequency;

     private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
         this.targetQNetwork = qNetwork.clone();
         this.targetUpdateFrequency = targetUpdateFrequency;
         tdTargetAlgorithm = isDoubleDQN
-                ? new DoubleDQN(this, gamma, errorClamp)
-                : new StandardDQN(this, gamma, errorClamp);
+                ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
+                : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
@@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
         DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
         qNetwork.fit(targets.getFeatures(), targets.getLabels());
         if(++updateCount % targetUpdateFrequency == 0) {
-            targetQNetwork = qNetwork.clone();
+            targetQNetwork.copy(qNetwork);
         }
     }
 }
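
For orientation, a minimal sketch of how the refactored update rule is wired after this change. The constructor parameter order and the IUpdateRule.update signature are assumptions inferred from the hunks above, not verified against the full sources:

    // Sketch only: parameter order and update(...) are inferred from this diff.
    IDQN qNetwork = /* build or load a DQN */ null;
    DQNNeuralNetUpdateRule updateRule = new DQNNeuralNetUpdateRule(
            qNetwork,
            500,    // targetUpdateFrequency
            true,   // isDoubleDQN
            0.99,   // gamma
            1.0);   // errorClamp

    // Each update fits the Q-network on the computed TD targets. Because
    // targetQNetwork is now final and handed directly to DoubleDQN/StandardDQN,
    // the sync step mutates it in place (targetQNetwork.copy(qNetwork)) rather
    // than re-cloning it, so the TD-target algorithms keep a valid reference.
    updateRule.update(trainingBatch);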

TargetQNetworkSource.java (deleted):
@@ -1,28 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-/**
- * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface TargetQNetworkSource extends QNetworkSource {
-    IDQN getTargetQNetwork();
-}

BaseDQNAlgorithm.java:
@@ -16,8 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;

 /**
@@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  */
 public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

-    private final TargetQNetworkSource qTargetNetworkSource;
+    private final IOutputNeuralNet targetQNetwork;

     /**
      * In litterature, this corresponds to Q{net}(s(t+1), a)
@@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
      */
     protected INDArray targetQNetworkNextObservation;

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, gamma);
+        this.targetQNetwork = targetQNetwork;
     }

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, gamma, errorClamp);
+        this.targetQNetwork = targetQNetwork;
     }

     @Override
     protected void initComputation(INDArray observations, INDArray nextObservations) {
         super.initComputation(observations, nextObservations);

-        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
-
-        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
+        qNetworkNextObservation = qNetwork.output(nextObservations);
         targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
     }
 }

BaseTDTargetAlgorithm.java:
@@ -17,7 +17,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;

@@ -30,7 +30,7 @@ import java.util.List;
  */
 public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

-    protected final QNetworkSource qNetworkSource;
+    protected final IOutputNeuralNet qNetwork;
     protected final double gamma;

     private final double errorClamp;
@@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
-        this.qNetworkSource = qNetworkSource;
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
+        this.qNetwork = qNetwork;
         this.gamma = gamma;

         this.errorClamp = errorClamp;
@@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * Note: Error clamping is disabled with this ctor
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
-        this(qNetworkSource, gamma, Double.NaN);
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
+        this(qNetwork, gamma, Double.NaN);
     }

     /**
@@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

         initComputation(observations, nextObservations);

-        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
-
+        INDArray updatedQValues = qNetwork.output(observations);
         for (int i = 0; i < size; ++i) {
             Transition<Integer> transition = transitions.get(i);
             double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
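
The errorClamp javadoc above says the new Q-value may not move more than errorClamp away from the previous value, with Double.NaN disabling the behaviour. A rough sketch of that rule as it would apply inside the per-sample loop; the previous-value lookup and the transition.getAction() accessor are assumptions, since that part of computeTDTargets falls outside these hunks:

    // Hypothetical sketch of the clamping rule described in the javadoc.
    double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
    if (!Double.isNaN(errorClamp)) {
        double previous = updatedQValues.getDouble(i, transition.getAction()); // assumed accessor
        yTarget = Math.min(Math.max(yTarget, previous - errorClamp), previous + errorClamp);
    }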

DoubleDQN.java:
@@ -16,7 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a)
     private INDArray maxActionsFromQNetworkNextObservation;

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override

StandardDQN.java:
@@ -16,7 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
     private INDArray maxActionsFromQTargetNextObservation;

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
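
The comments in the two classes above name the only difference between the algorithms: StandardDQN bootstraps from max_{a}Q_{tar}(s_{t+1}, a), while DoubleDQN selects the action with the online network and evaluates it with the target network. A condensed sketch of both targets; qNext and targetNext stand for the qNetworkNextObservation / targetQNetworkNextObservation arrays computed in BaseDQNAlgorithm.initComputation, and the ND4J indexing calls are illustrative rather than quoted from these classes:

    // Standard DQN: bootstrap from the target network's own best action.
    double standardTarget = reward + gamma * targetNext.getRow(i).maxNumber().doubleValue();

    // Double DQN: choose the action with the online network, score it with the
    // target network; this decoupling reduces over-estimation of Q-values.
    int bestAction = qNext.getRow(i).argMax().getInt(0);
    double doubleTarget = reward + gamma * targetNext.getDouble(i, bestAction);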

QNetworkSource.java → IOutputNeuralNet.java:
@@ -1,28 +1,38 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-/**
- * An interface for all implementations capable of supplying a Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface QNetworkSource {
-    IDQN getQNetwork();
-}
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.rl4j.network;
+
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * An interface defining the output aspect of a {@link NeuralNet}.
+ */
+public interface IOutputNeuralNet {
+    /**
+     * Compute the output for the supplied observation.
+     * @param observation An {@link Observation}
+     * @return The output of the network
+     */
+    INDArray output(Observation observation);
+
+    /**
+     * Compute the output for the supplied batch.
+     * @param batch The batch to compute the output for
+     * @return The output of the network
+     */
+    INDArray output(INDArray batch);
+}
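
Since the TD-target algorithms now depend only on IOutputNeuralNet, anything that maps observations to outputs can be plugged into them. A minimal hypothetical implementation, purely illustrative and not part of this commit; the Observation.getData() accessor is an assumption about how Observation exposes its backing array:

    // Hypothetical IOutputNeuralNet: a frozen "network" that just scales its input.
    public class ScalingOutputNeuralNet implements IOutputNeuralNet {
        private final double factor;

        public ScalingOutputNeuralNet(double factor) {
            this.factor = factor;
        }

        @Override
        public INDArray output(Observation observation) {
            return output(observation.getData()); // getData() assumed to expose the backing INDArray
        }

        @Override
        public INDArray output(INDArray batch) {
            return batch.mul(factor);
        }
    }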

IDQN.java:
@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.network.dqn;

 import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 * This neural net quantify the value of each action given a state
 *
 */
-public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
+public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {

     boolean isRecurrent();

@@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {

     void fit(INDArray input, INDArray[] labels);

-    INDArray output(INDArray batch);
-    INDArray output(Observation observation);
-
     INDArray[] outputAll(INDArray batch);

     NN clone();

DoubleDQNTest.java:
@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -13,16 +16,29 @@ import java.util.ArrayList;
 import java.util.List;

 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class DoubleDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -31,7 +47,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -46,9 +62,7 @@ public class DoubleDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -57,7 +71,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -72,9 +86,7 @@ public class DoubleDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -87,7 +99,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
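
The hand-rolled MockDQN/MockTargetQNetworkSource pair is replaced by Mockito stubs against IOutputNeuralNet. The same pattern shown standalone: the online network echoes its input and the target network negates it, so the expected TD targets can be computed by hand. This assumes org.mockito.Mockito.mock/when and ArgumentMatchers.any on the classpath, as the imports above indicate:

    IOutputNeuralNet qNetwork = mock(IOutputNeuralNet.class);
    IOutputNeuralNet targetQNetwork = mock(IOutputNeuralNet.class);

    // Identity stub: Q(s) == s, so Q-values are read directly off the observations.
    when(qNetwork.output(any(INDArray.class)))
            .thenAnswer(i -> i.getArguments()[0]);

    // Negating stub: Q_tar(s) == -s, making target values distinguishable from online ones.
    when(targetQNetwork.output(any(INDArray.class)))
            .thenAnswer(i -> ((INDArray) i.getArguments()[0]).mul(-1.0));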

StandardDQNTest.java:
@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.List;

-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class StandardDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -30,7 +47,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -45,10 +62,6 @@ public class StandardDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -56,7 +69,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -71,10 +84,6 @@ public class StandardDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -86,7 +95,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

MockTargetQNetworkSource.java (deleted):
@@ -1,26 +0,0 @@
-package org.deeplearning4j.rl4j.learning.sync.support;
-
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-public class MockTargetQNetworkSource implements TargetQNetworkSource {
-
-
-    private final IDQN qNetwork;
-    private final IDQN targetQNetwork;
-
-    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
-        this.qNetwork = qNetwork;
-        this.targetQNetwork = targetQNetwork;
-    }
-
-    @Override
-    public IDQN getTargetQNetwork() {
-        return targetQNetwork;
-    }
-
-    @Override
-    public IDQN getQNetwork() {
-        return qNetwork;
-    }
-}