RL4J: Add NeuralNetUpdater (#500)
Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
69ebc96068
commit
4190c9ee0f
|
@ -16,6 +16,8 @@
|
|||
package org.deeplearning4j.rl4j.agent.update;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.INeuralNetUpdater;
|
||||
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.NeuralNetUpdater;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||
|
@ -29,30 +31,26 @@ import java.util.List;
|
|||
// and network update to sub components.
|
||||
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
|
||||
|
||||
private final IDQN qNetwork;
|
||||
private final IDQN targetQNetwork;
|
||||
private final int targetUpdateFrequency;
|
||||
private final INeuralNetUpdater updater;
|
||||
|
||||
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
||||
|
||||
@Getter
|
||||
private int updateCount = 0;
|
||||
|
||||
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
||||
this.qNetwork = qNetwork;
|
||||
public DQNNeuralNetUpdateRule(IDQN<IDQN> qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
||||
this.targetQNetwork = qNetwork.clone();
|
||||
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||
tdTargetAlgorithm = isDoubleDQN
|
||||
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
|
||||
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
|
||||
updater = new NeuralNetUpdater(qNetwork, targetQNetwork, targetUpdateFrequency);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void update(List<Transition<Integer>> trainingBatch) {
|
||||
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
||||
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
||||
if(++updateCount % targetUpdateFrequency == 0) {
|
||||
targetQNetwork.copy(qNetwork);
|
||||
}
|
||||
DataSet targets = tdTargetAlgorithm.compute(trainingBatch);
|
||||
updater.update(targets);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
/**
|
||||
* The role of INeuralNetUpdater implementations is to update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}
|
||||
* from a {@link DataSet}.<p />
|
||||
*/
|
||||
public interface INeuralNetUpdater {
|
||||
/**
|
||||
* Update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}.
|
||||
* @param featuresLabels A Dataset that will be used to update the network.
|
||||
*/
|
||||
void update(DataSet featuresLabels);
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
/**
|
||||
* A {@link INeuralNetUpdater} that updates a neural network and sync a target network at defined intervals
|
||||
*/
|
||||
public class NeuralNetUpdater implements INeuralNetUpdater {
|
||||
|
||||
private final ITrainableNeuralNet current;
|
||||
private final ITrainableNeuralNet target;
|
||||
|
||||
private int updateCount = 0;
|
||||
private final int targetUpdateFrequency;
|
||||
|
||||
/**
|
||||
* @param current The current {@link ITrainableNeuralNet network}
|
||||
* @param target The target {@link ITrainableNeuralNet network}
|
||||
* @param targetUpdateFrequency Will synchronize the target network at every <i>targetUpdateFrequency</i> updates
|
||||
*/
|
||||
public NeuralNetUpdater(ITrainableNeuralNet current,
|
||||
ITrainableNeuralNet target,
|
||||
int targetUpdateFrequency) {
|
||||
this.current = current;
|
||||
this.target = target;
|
||||
|
||||
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the current network
|
||||
* @param featuresLabels A Dataset that will be used to update the network.
|
||||
*/
|
||||
@Override
|
||||
public void update(DataSet featuresLabels) {
|
||||
current.fit(featuresLabels);
|
||||
syncTargetNetwork();
|
||||
}
|
||||
|
||||
private void syncTargetNetwork() {
|
||||
if(++updateCount % targetUpdateFrequency == 0) {
|
||||
target.copy(current);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -23,16 +23,19 @@ import lombok.Setter;
|
|||
import lombok.Value;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.*;
|
||||
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
|
|
@ -18,11 +18,9 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public interface IAsyncGlobal<NN extends NeuralNet> {
|
||||
public interface IAsyncGlobal<NN extends ITrainableNeuralNet> {
|
||||
|
||||
boolean isTrainingComplete();
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
|
|||
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
|
||||
|
||||
@Override
|
||||
public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
|
||||
public DataSet compute(List<Transition<Integer>> transitions) {
|
||||
|
||||
int size = transitions.size();
|
||||
|
||||
|
|
|
@ -34,5 +34,5 @@ public interface ITDTargetAlgorithm<A> {
|
|||
* @param transitions The transitions from the experience replay
|
||||
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions
|
||||
*/
|
||||
DataSet computeTDTargets(List<Transition<A>> transitions);
|
||||
DataSet compute(List<Transition<A>> transitions);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.network;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
/**
|
||||
* An interface defining the <i>trainable</i> aspect of a {@link NeuralNet}.
|
||||
*/
|
||||
public interface ITrainableNeuralNet<NET_TYPE extends ITrainableNeuralNet> {
|
||||
/**
|
||||
* Train the neural net using the supplied <i>feature-labels</i>
|
||||
* @param featuresLabels The feature-labels
|
||||
*/
|
||||
void fit(DataSet featuresLabels);
|
||||
|
||||
/**
|
||||
* Changes this instance to be a copy of the <i>from</i> network.
|
||||
* @param from The network that will be the source of the copy.
|
||||
*/
|
||||
void copy(NET_TYPE from);
|
||||
|
||||
/**
|
||||
* Creates a clone of the network instance.
|
||||
*/
|
||||
NET_TYPE clone();
|
||||
}
|
|
@ -29,7 +29,7 @@ import java.io.OutputStream;
|
|||
* Factorisation between ActorCritic and DQN neural net.
|
||||
* Useful for AsyncLearning and Thread code.
|
||||
*/
|
||||
public interface NeuralNet<NN extends NeuralNet> {
|
||||
public interface NeuralNet<NN extends NeuralNet> extends IOutputNeuralNet, ITrainableNeuralNet<NN> {
|
||||
|
||||
/**
|
||||
* Returns the underlying MultiLayerNetwork or ComputationGraph objects.
|
||||
|
@ -52,18 +52,6 @@ public interface NeuralNet<NN extends NeuralNet> {
|
|||
*/
|
||||
INDArray[] outputAll(INDArray batch);
|
||||
|
||||
/**
|
||||
* clone the Neural Net with the same paramaeters
|
||||
* @return the cloned neural net
|
||||
*/
|
||||
NN clone();
|
||||
|
||||
/**
|
||||
* copy the parameters from a neural net
|
||||
* @param from where to copy parameters
|
||||
*/
|
||||
void copy(NN from);
|
||||
|
||||
/**
|
||||
* Calculate the gradients from input and label (target) of all outputs
|
||||
* @param input input batch
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.deeplearning4j.rl4j.network.ac;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
|
@ -25,8 +26,10 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
|||
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
@ -80,6 +83,11 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
|
|||
return nn;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
fit(featuresLabels.getFeatures(), new INDArray[] { featuresLabels.getLabels() });
|
||||
}
|
||||
|
||||
public void copy(ActorCriticCompGraph from) {
|
||||
cg.setParams(from.cg.params());
|
||||
}
|
||||
|
@ -137,5 +145,19 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
|
|||
public void save(String pathValue, String pathPolicy) throws IOException {
|
||||
throw new UnsupportedOperationException("Call save(path)");
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(Observation observation) {
|
||||
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||
// this one (output from the value-network and another output for the policy-network
|
||||
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(INDArray batch) {
|
||||
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||
// this one (output from the value-network and another output for the policy-network
|
||||
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.rl4j.network.ac;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
|
@ -24,8 +25,10 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
@ -86,6 +89,13 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
|
|||
return nn;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
// TODO: signature of fit() will change from DataSet to a class that has named labels to support network like
|
||||
// this one (labels for the value-network and another labels for the policy-network
|
||||
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||
}
|
||||
|
||||
public void copy(NN from) {
|
||||
valueNet.setParams(from.valueNet.params());
|
||||
policyNet.setParams(from.policyNet.params());
|
||||
|
@ -164,6 +174,20 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
|
|||
ModelSerializer.writeModel(valueNet, pathValue, true);
|
||||
ModelSerializer.writeModel(policyNet, pathPolicy, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(Observation observation) {
|
||||
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||
// this one (output from the value-network and another output for the policy-network
|
||||
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(INDArray batch) {
|
||||
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||
// this one (output from the value-network and another output for the policy-network
|
||||
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
package org.deeplearning4j.rl4j.network.ac;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -33,27 +32,11 @@ import java.io.OutputStream;
|
|||
*/
|
||||
public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> {
|
||||
|
||||
boolean isRecurrent();
|
||||
|
||||
void reset();
|
||||
|
||||
void fit(INDArray input, INDArray[] labels);
|
||||
|
||||
//FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) !
|
||||
INDArray[] outputAll(INDArray batch);
|
||||
|
||||
NN clone();
|
||||
|
||||
void copy(NN from);
|
||||
|
||||
Gradient[] gradient(INDArray input, INDArray[] labels);
|
||||
|
||||
void applyGradient(Gradient[] gradient, int batchSize);
|
||||
|
||||
void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException;
|
||||
|
||||
void save(String pathValue, String pathPolicy) throws IOException;
|
||||
|
||||
double getLatestScore();
|
||||
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.optimize.api.TrainingListener;
|
|||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
@ -33,7 +34,7 @@ import java.util.Collection;
|
|||
/**
|
||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
||||
*/
|
||||
public class DQN<NN extends DQN> implements IDQN<NN> {
|
||||
public class DQN implements IDQN<DQN> {
|
||||
|
||||
final protected MultiLayerNetwork mln;
|
||||
|
||||
|
@ -79,16 +80,23 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
|
|||
return new INDArray[] {output(batch)};
|
||||
}
|
||||
|
||||
public NN clone() {
|
||||
NN nn = (NN)new DQN(mln.clone());
|
||||
nn.mln.setListeners(mln.getListeners());
|
||||
return nn;
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
fit(featuresLabels.getFeatures(), featuresLabels.getLabels());
|
||||
}
|
||||
|
||||
public void copy(NN from) {
|
||||
@Override
|
||||
public void copy(DQN from) {
|
||||
mln.setParams(from.mln.params());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DQN clone() {
|
||||
DQN nn = new DQN(mln.clone());
|
||||
nn.mln.setListeners(mln.getListeners());
|
||||
return nn;
|
||||
}
|
||||
|
||||
public Gradient[] gradient(INDArray input, INDArray labels) {
|
||||
mln.setInput(input);
|
||||
mln.setLabels(labels);
|
||||
|
|
|
@ -17,9 +17,7 @@
|
|||
package org.deeplearning4j.rl4j.network.dqn;
|
||||
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
|
@ -28,27 +26,11 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
* This neural net quantify the value of each action given a state
|
||||
*
|
||||
*/
|
||||
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
|
||||
|
||||
boolean isRecurrent();
|
||||
|
||||
void reset();
|
||||
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
|
||||
|
||||
void fit(INDArray input, INDArray labels);
|
||||
|
||||
void fit(INDArray input, INDArray[] labels);
|
||||
|
||||
INDArray[] outputAll(INDArray batch);
|
||||
|
||||
NN clone();
|
||||
|
||||
void copy(NN from);
|
||||
|
||||
Gradient[] gradient(INDArray input, INDArray label);
|
||||
|
||||
Gradient[] gradient(INDArray input, INDArray[] label);
|
||||
|
||||
void applyGradient(Gradient[] gradient, int batchSize);
|
||||
|
||||
double getLatestScore();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.MockitoJUnitRunner;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@RunWith(MockitoJUnitRunner.class)
|
||||
public class NeuralNetUpdaterTest {
|
||||
|
||||
@Mock
|
||||
ITrainableNeuralNet currentMock;
|
||||
|
||||
@Mock
|
||||
ITrainableNeuralNet targetMock;
|
||||
|
||||
@Test
|
||||
public void when_callingUpdate_expect_currentUpdatedAndtargetNotChanged() {
|
||||
// Arrange
|
||||
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, Integer.MAX_VALUE);
|
||||
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
|
||||
|
||||
// Act
|
||||
sut.update(featureLabels);
|
||||
|
||||
// Assert
|
||||
verify(currentMock, times(1)).fit(featureLabels);
|
||||
verify(targetMock, never()).fit(any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_callingUpdate_expect_targetUpdatedFromCurrentAtFrequency() {
|
||||
// Arrange
|
||||
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, 3);
|
||||
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
|
||||
|
||||
// Act
|
||||
sut.update(featureLabels);
|
||||
sut.update(featureLabels);
|
||||
sut.update(featureLabels);
|
||||
|
||||
// Assert
|
||||
verify(currentMock, never()).copy(any());
|
||||
verify(targetMock, times(1)).copy(currentMock);
|
||||
}
|
||||
|
||||
}
|
|
@ -50,7 +50,7 @@ public class DoubleDQNTest {
|
|||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
@ -74,7 +74,7 @@ public class DoubleDQNTest {
|
|||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
@ -102,7 +102,7 @@ public class DoubleDQNTest {
|
|||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
|
|
@ -50,7 +50,7 @@ public class StandardDQNTest {
|
|||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
@ -72,7 +72,7 @@ public class StandardDQNTest {
|
|||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
@ -98,7 +98,7 @@ public class StandardDQNTest {
|
|||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||
|
||||
// Act
|
||||
DataSet result = sut.computeTDTargets(transitions);
|
||||
DataSet result = sut.compute(transitions);
|
||||
|
||||
// Assert
|
||||
INDArray evaluatedQValues = result.getLabels();
|
||||
|
|
|
@ -2,10 +2,12 @@ package org.deeplearning4j.rl4j.learning.sync.support;
|
|||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
@ -66,21 +68,21 @@ public class MockDQN implements IDQN {
|
|||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(ITrainableNeuralNet from) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public IDQN clone() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(IDQN from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||
return new Gradient[0];
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
|
@ -88,6 +89,11 @@ public class PolicyTest {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NN from) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
@ -127,6 +133,16 @@ public class PolicyTest {
|
|||
public void save(String filename) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(Observation observation) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(INDArray batch) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -2,11 +2,13 @@ package org.deeplearning4j.rl4j.support;
|
|||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
|
@ -63,6 +65,16 @@ public class MockDQN implements IDQN {
|
|||
return new INDArray[] { batch.mul(-1.0) };
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(DataSet featuresLabels) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(ITrainableNeuralNet from) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public IDQN clone() {
|
||||
MockDQN clone = new MockDQN();
|
||||
|
@ -71,16 +83,6 @@ public class MockDQN implements IDQN {
|
|||
return clone;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(IDQN from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||
gradientParams.add(new Pair<INDArray, INDArray>(input, label));
|
||||
|
|
|
@ -2,8 +2,11 @@ package org.deeplearning4j.rl4j.support;
|
|||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -39,15 +42,20 @@ public class MockNeuralNet implements NeuralNet {
|
|||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return this;
|
||||
public void fit(DataSet featuresLabels) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
public void copy(ITrainableNeuralNet from) {
|
||||
++copyCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NeuralNet clone() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||
return new Gradient[0];
|
||||
|
@ -77,4 +85,14 @@ public class MockNeuralNet implements NeuralNet {
|
|||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(Observation observation) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(INDArray batch) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue