RL4J: Use directly NeuralNet instances in DoubleDQN and StandardDQN (#499)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

parent 654afc810d
commit fb578fdecd
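The gist of the change: DoubleDQN and StandardDQN now receive the online and target networks directly, instead of reaching through a TargetQNetworkSource. A minimal wiring sketch, assuming only what the diff below shows (the helper class and method names are illustrative, not part of the commit):

    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
    import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
    import org.deeplearning4j.rl4j.network.dqn.IDQN;

    class UpdateRuleWiringSketch {
        // Mirrors what DQNNeuralNetUpdateRule does after this commit:
        // clone the online network once, then hand both instances to the algorithm.
        static ITDTargetAlgorithm<Integer> build(IDQN qNetwork, boolean isDoubleDQN,
                                                 double gamma, double errorClamp) {
            IDQN targetQNetwork = qNetwork.clone();
            return isDoubleDQN
                    ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
                    : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
        }
    }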
DQNNeuralNetUpdateRule.java (package org.deeplearning4j.rl4j.agent.update):

@@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
 import lombok.Getter;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
@@ -28,13 +27,10 @@ import java.util.List;
 // Temporary class that will be replaced with a more generic class that delegates gradient computation
 // and network update to sub components.
-public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
+public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
 
-    @Getter
     private final IDQN qNetwork;
-
-    @Getter
-    private IDQN targetQNetwork;
+    private final IDQN targetQNetwork;
 
     private final int targetUpdateFrequency;
 
     private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>
         this.targetQNetwork = qNetwork.clone();
         this.targetUpdateFrequency = targetUpdateFrequency;
         tdTargetAlgorithm = isDoubleDQN
-                ? new DoubleDQN(this, gamma, errorClamp)
-                : new StandardDQN(this, gamma, errorClamp);
+                ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
+                : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
     }
 
     @Override
@@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>
         DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
         qNetwork.fit(targets.getFeatures(), targets.getLabels());
         if(++updateCount % targetUpdateFrequency == 0) {
-            targetQNetwork = qNetwork.clone();
+            targetQNetwork.copy(qNetwork);
         }
     }
 }
TargetQNetworkSource.java (package org.deeplearning4j.rl4j.learning.sync.qlearning, deleted):

@@ -1,28 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-/**
- * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface TargetQNetworkSource extends QNetworkSource {
-    IDQN getTargetQNetwork();
-}
BaseDQNAlgorithm.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -16,8 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
@@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  */
 public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
 
-    private final TargetQNetworkSource qTargetNetworkSource;
+    private final IOutputNeuralNet targetQNetwork;
 
     /**
      * In litterature, this corresponds to Q{net}(s(t+1), a)
@@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
      */
     protected INDArray targetQNetworkNextObservation;
 
-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, gamma);
+        this.targetQNetwork = targetQNetwork;
     }
 
-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, gamma, errorClamp);
+        this.targetQNetwork = targetQNetwork;
     }
 
     @Override
     protected void initComputation(INDArray observations, INDArray nextObservations) {
         super.initComputation(observations, nextObservations);
 
-        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
-
-        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
+        qNetworkNextObservation = qNetwork.output(nextObservations);
         targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
     }
 }
BaseTDTargetAlgorithm.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -17,7 +17,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 
@@ -30,7 +30,7 @@ import java.util.List;
  */
 public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
 
-    protected final QNetworkSource qNetworkSource;
+    protected final IOutputNeuralNet qNetwork;
     protected final double gamma;
 
     private final double errorClamp;
@@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
 
     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
-        this.qNetworkSource = qNetworkSource;
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
+        this.qNetwork = qNetwork;
         this.gamma = gamma;
 
         this.errorClamp = errorClamp;
@@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
 
     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * Note: Error clamping is disabled with this ctor
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
-        this(qNetworkSource, gamma, Double.NaN);
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
+        this(qNetwork, gamma, Double.NaN);
     }
 
     /**
@@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
 
         initComputation(observations, nextObservations);
 
-        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
+        INDArray updatedQValues = qNetwork.output(observations);
 
         for (int i = 0; i < size; ++i) {
             Transition<Integer> transition = transitions.get(i);
             double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
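The errorClamp javadoc above states the contract; a sketch of that rule under the javadoc's own description (the helper and variable names are hypothetical, not taken from the file):

    // Hypothetical helper: keeps the new TD target within errorClamp of the
    // previous Q-value estimate; Double.NaN disables clamping, as documented.
    static double clampTarget(double yTarget, double previousQValue, double errorClamp) {
        if (Double.isNaN(errorClamp)) {
            return yTarget; // clamping disabled
        }
        return Math.max(previousQValue - errorClamp,
                        Math.min(yTarget, previousQValue + errorClamp));
    }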
DoubleDQN.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -16,7 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
@@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a)
     private INDArray maxActionsFromQNetworkNextObservation;
 
-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }
 
-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }
 
     @Override
StandardDQN.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -16,7 +16,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
@@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
     // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
     private INDArray maxActionsFromQTargetNextObservation;
 
-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }
 
-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
    }
 
     @Override
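For reference, the TD targets behind the two comments above (standard DQN background, not text from the commit), written in the comments' own notation:

    Standard DQN:  y_t = r_t + gamma * max_{a}Q_{tar}(s_{t+1}, a)
    Double DQN:    y_t = r_t + gamma * Q_{tar}(s_{t+1}, argmax_{a}Q{net}(s_{t+1}, a))
    Terminal:      y_t = r_t

This is why DoubleDQN caches maxActionsFromQNetworkNextObservation (the action is selected with the online network, then evaluated by the target network), while StandardDQN caches maxActionsFromQTargetNextObservation (selection and evaluation both use the target network).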
QNetworkSource.java renamed to IOutputNeuralNet.java (package org.deeplearning4j.rl4j.learning.sync.qlearning moved to org.deeplearning4j.rl4j.network):

@@ -1,28 +1,38 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  * License for the specific language governing permissions and limitations
  * under the License.
  *
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
 
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
+package org.deeplearning4j.rl4j.network;
 
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
- * An interface for all implementations capable of supplying a Q-Network
- *
- * @author Alexandre Boulanger
+ * An interface defining the output aspect of a {@link NeuralNet}.
  */
-public interface QNetworkSource {
-    IDQN getQNetwork();
-}
+public interface IOutputNeuralNet {
+    /**
+     * Compute the output for the supplied observation.
+     * @param observation An {@link Observation}
+     * @return The output of the network
+     */
+    INDArray output(Observation observation);
+
+    /**
+     * Compute the output for the supplied batch.
+     * @param batch
+     * @return The output of the network
+     */
+    INDArray output(INDArray batch);
+}
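A quick feel for the seam this interface creates; a minimal stub, assuming Observation exposes its underlying INDArray via getData() (that accessor is not shown in this diff):

    import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
    import org.deeplearning4j.rl4j.observation.Observation;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Identity stub: echoes its input back as the "network output", the same
    // behaviour the Mockito answers in the updated tests below simulate.
    public class IdentityOutputNeuralNet implements IOutputNeuralNet {
        @Override
        public INDArray output(Observation observation) {
            return observation.getData(); // assumption: Observation#getData() returns the INDArray
        }

        @Override
        public INDArray output(INDArray batch) {
            return batch;
        }
    }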
IDQN.java (package org.deeplearning4j.rl4j.network.dqn):

@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.network.dqn;
 
 import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * This neural net quantify the value of each action given a state
  *
  */
-public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
+public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
 
     boolean isRecurrent();
 
@@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
 
     void fit(INDArray input, INDArray[] labels);
 
-    INDArray output(INDArray batch);
-    INDArray output(Observation observation);
-
     INDArray[] outputAll(INDArray batch);
 
     NN clone();
DoubleDQNTest.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -13,16 +16,29 @@ import java.util.ArrayList;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
 
+@RunWith(MockitoJUnitRunner.class)
 public class DoubleDQNTest {
 
+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
 
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -31,7 +47,7 @@ public class DoubleDQNTest {
             }
         };
 
-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -46,9 +62,7 @@ public class DoubleDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
 
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -57,7 +71,7 @@ public class DoubleDQNTest {
             }
         };
 
-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -72,9 +86,7 @@ public class DoubleDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
 
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
@@ -87,7 +99,7 @@ public class DoubleDQNTest {
             }
         };
 
-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
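The mocking pattern used in the test above, in isolation: an identity answer stands in for the online network, and a negated answer for the target network, so the assertions can tell the two networks' contributions apart (the @Mock fields and Mockito runner from the test class are assumed):

    // Online network: output equals the input batch.
    when(qNetworkMock.output(any(INDArray.class)))
            .thenAnswer(i -> i.getArguments()[0]);

    // Target network: output equals the negated input batch, so target-network
    // terms are distinguishable from online-network terms in expected values.
    when(targetQNetworkMock.output(any(INDArray.class)))
            .thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));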
StandardDQNTest.java (package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm):

@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
 
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.List;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
 
+@RunWith(MockitoJUnitRunner.class)
 public class StandardDQNTest {
 
+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
 
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -30,7 +47,7 @@ public class StandardDQNTest {
             }
         };
 
-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -45,10 +62,6 @@ public class StandardDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -56,7 +69,7 @@ public class StandardDQNTest {
             }
         };
 
-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -71,10 +84,6 @@ public class StandardDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
 
         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
-
         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@@ -86,7 +95,7 @@ public class StandardDQNTest {
             }
         };
 
-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
 
         // Act
         DataSet result = sut.computeTDTargets(transitions);
MockTargetQNetworkSource.java (package org.deeplearning4j.rl4j.learning.sync.support, deleted):

@@ -1,26 +0,0 @@
-package org.deeplearning4j.rl4j.learning.sync.support;
-
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-public class MockTargetQNetworkSource implements TargetQNetworkSource {
-
-
-    private final IDQN qNetwork;
-    private final IDQN targetQNetwork;
-
-    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
-        this.qNetwork = qNetwork;
-        this.targetQNetwork = targetQNetwork;
-    }
-
-    @Override
-    public IDQN getTargetQNetwork() {
-        return targetQNetwork;
-    }
-
-    @Override
-    public IDQN getQNetwork() {
-        return qNetwork;
-    }
-}