RL4J: Add NeuralNetUpdater (#500)

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>
master
Alexandre Boulanger 2020-06-28 23:20:53 -04:00 committed by GitHub
parent 69ebc96068
commit 4190c9ee0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 329 additions and 102 deletions

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter;
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.INeuralNetUpdater;
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.NeuralNetUpdater;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
@ -29,30 +31,26 @@ import java.util.List;
// and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
private final IDQN qNetwork;
private final IDQN targetQNetwork;
private final int targetUpdateFrequency;
private final INeuralNetUpdater updater;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@Getter
private int updateCount = 0;
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
this.qNetwork = qNetwork;
public DQNNeuralNetUpdateRule(IDQN<IDQN> qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
updater = new NeuralNetUpdater(qNetwork, targetQNetwork, targetUpdateFrequency);
}
@Override
public void update(List<Transition<Integer>> trainingBatch) {
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels());
if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork.copy(qNetwork);
}
DataSet targets = tdTargetAlgorithm.compute(trainingBatch);
updater.update(targets);
}
}

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* The role of INeuralNetUpdater implementations is to update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}
* from a {@link DataSet}.<p />
*/
public interface INeuralNetUpdater {
/**
* Update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}.
* @param featuresLabels A Dataset that will be used to update the network.
*/
void update(DataSet featuresLabels);
}

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* A {@link INeuralNetUpdater} that updates a neural network and sync a target network at defined intervals
*/
public class NeuralNetUpdater implements INeuralNetUpdater {
private final ITrainableNeuralNet current;
private final ITrainableNeuralNet target;
private int updateCount = 0;
private final int targetUpdateFrequency;
/**
* @param current The current {@link ITrainableNeuralNet network}
* @param target The target {@link ITrainableNeuralNet network}
* @param targetUpdateFrequency Will synchronize the target network at every <i>targetUpdateFrequency</i> updates
*/
public NeuralNetUpdater(ITrainableNeuralNet current,
ITrainableNeuralNet target,
int targetUpdateFrequency) {
this.current = current;
this.target = target;
this.targetUpdateFrequency = targetUpdateFrequency;
}
/**
* Update the current network
* @param featuresLabels A Dataset that will be used to update the network.
*/
@Override
public void update(DataSet featuresLabels) {
current.fit(featuresLabels);
syncTargetNetwork();
}
private void syncTargetNetwork() {
if(++updateCount % targetUpdateFrequency == 0) {
target.copy(current);
}
}
}

View File

@ -23,16 +23,19 @@ import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.*;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -18,11 +18,9 @@
package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import java.util.concurrent.atomic.AtomicInteger;
public interface IAsyncGlobal<NN extends NeuralNet> {
public interface IAsyncGlobal<NN extends ITrainableNeuralNet> {
boolean isTrainingComplete();

View File

@ -80,7 +80,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
@Override
public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
public DataSet compute(List<Transition<Integer>> transitions) {
int size = transitions.size();

View File

@ -34,5 +34,5 @@ public interface ITDTargetAlgorithm<A> {
* @param transitions The transitions from the experience replay
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions
*/
DataSet computeTDTargets(List<Transition<A>> transitions);
DataSet compute(List<Transition<A>> transitions);
}

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* An interface defining the <i>trainable</i> aspect of a {@link NeuralNet}.
*/
public interface ITrainableNeuralNet<NET_TYPE extends ITrainableNeuralNet> {
/**
* Train the neural net using the supplied <i>feature-labels</i>
* @param featuresLabels The feature-labels
*/
void fit(DataSet featuresLabels);
/**
* Changes this instance to be a copy of the <i>from</i> network.
* @param from The network that will be the source of the copy.
*/
void copy(NET_TYPE from);
/**
* Creates a clone of the network instance.
*/
NET_TYPE clone();
}

View File

@ -29,7 +29,7 @@ import java.io.OutputStream;
* Factorisation between ActorCritic and DQN neural net.
* Useful for AsyncLearning and Thread code.
*/
public interface NeuralNet<NN extends NeuralNet> {
public interface NeuralNet<NN extends NeuralNet> extends IOutputNeuralNet, ITrainableNeuralNet<NN> {
/**
* Returns the underlying MultiLayerNetwork or ComputationGraph objects.
@ -52,18 +52,6 @@ public interface NeuralNet<NN extends NeuralNet> {
*/
INDArray[] outputAll(INDArray batch);
/**
* clone the Neural Net with the same paramaeters
* @return the cloned neural net
*/
NN clone();
/**
* copy the parameters from a neural net
* @param from where to copy parameters
*/
void copy(NN from);
/**
* Calculate the gradients from input and label (target) of all outputs
* @param input input batch

View File

@ -18,6 +18,7 @@
package org.deeplearning4j.rl4j.network.ac;
import lombok.Getter;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
@ -25,8 +26,10 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException;
import java.io.OutputStream;
@ -80,6 +83,11 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
return nn;
}
@Override
public void fit(DataSet featuresLabels) {
fit(featuresLabels.getFeatures(), new INDArray[] { featuresLabels.getLabels() });
}
public void copy(ActorCriticCompGraph from) {
cg.setParams(from.cg.params());
}
@ -137,5 +145,19 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
public void save(String pathValue, String pathPolicy) throws IOException {
throw new UnsupportedOperationException("Call save(path)");
}
@Override
public INDArray output(Observation observation) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
@Override
public INDArray output(INDArray batch) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.network.ac;
import lombok.Getter;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
@ -24,8 +25,10 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException;
import java.io.OutputStream;
@ -86,6 +89,13 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
return nn;
}
@Override
public void fit(DataSet featuresLabels) {
// TODO: signature of fit() will change from DataSet to a class that has named labels to support network like
// this one (labels for the value-network and another labels for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
public void copy(NN from) {
valueNet.setParams(from.valueNet.params());
policyNet.setParams(from.policyNet.params());
@ -164,6 +174,20 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
ModelSerializer.writeModel(valueNet, pathValue, true);
ModelSerializer.writeModel(policyNet, pathPolicy, true);
}
@Override
public INDArray output(Observation observation) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
@Override
public INDArray output(INDArray batch) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
}

View File

@ -16,7 +16,6 @@
package org.deeplearning4j.rl4j.network.ac;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -33,27 +32,11 @@ import java.io.OutputStream;
*/
public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> {
boolean isRecurrent();
void reset();
void fit(INDArray input, INDArray[] labels);
//FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) !
INDArray[] outputAll(INDArray batch);
NN clone();
void copy(NN from);
Gradient[] gradient(INDArray input, INDArray[] labels);
void applyGradient(Gradient[] gradient, int batchSize);
void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException;
void save(String pathValue, String pathPolicy) throws IOException;
double getLatestScore();
}

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException;
import java.io.OutputStream;
@ -33,7 +34,7 @@ import java.util.Collection;
/**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
*/
public class DQN<NN extends DQN> implements IDQN<NN> {
public class DQN implements IDQN<DQN> {
final protected MultiLayerNetwork mln;
@ -79,16 +80,23 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
return new INDArray[] {output(batch)};
}
public NN clone() {
NN nn = (NN)new DQN(mln.clone());
nn.mln.setListeners(mln.getListeners());
return nn;
@Override
public void fit(DataSet featuresLabels) {
fit(featuresLabels.getFeatures(), featuresLabels.getLabels());
}
public void copy(NN from) {
@Override
public void copy(DQN from) {
mln.setParams(from.mln.params());
}
@Override
public DQN clone() {
DQN nn = new DQN(mln.clone());
nn.mln.setListeners(mln.getListeners());
return nn;
}
public Gradient[] gradient(INDArray input, INDArray labels) {
mln.setInput(input);
mln.setLabels(labels);

View File

@ -17,9 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
@ -28,27 +26,11 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* This neural net quantify the value of each action given a state
*
*/
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
boolean isRecurrent();
void reset();
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
void fit(INDArray input, INDArray labels);
void fit(INDArray input, INDArray[] labels);
INDArray[] outputAll(INDArray batch);
NN clone();
void copy(NN from);
Gradient[] gradient(INDArray input, INDArray label);
Gradient[] gradient(INDArray input, INDArray[] label);
void applyGradient(Gradient[] gradient, int batchSize);
double getLatestScore();
}

View File

@ -0,0 +1,51 @@
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.dataset.api.DataSet;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class NeuralNetUpdaterTest {
@Mock
ITrainableNeuralNet currentMock;
@Mock
ITrainableNeuralNet targetMock;
@Test
public void when_callingUpdate_expect_currentUpdatedAndtargetNotChanged() {
// Arrange
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, Integer.MAX_VALUE);
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
// Act
sut.update(featureLabels);
// Assert
verify(currentMock, times(1)).fit(featureLabels);
verify(targetMock, never()).fit(any());
}
@Test
public void when_callingUpdate_expect_targetUpdatedFromCurrentAtFrequency() {
// Arrange
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, 3);
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
// Act
sut.update(featureLabels);
sut.update(featureLabels);
sut.update(featureLabels);
// Assert
verify(currentMock, never()).copy(any());
verify(targetMock, times(1)).copy(currentMock);
}
}

View File

@ -50,7 +50,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
@ -74,7 +74,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
@ -102,7 +102,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();

View File

@ -50,7 +50,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
@ -72,7 +72,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();
@ -98,7 +98,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act
DataSet result = sut.computeTDTargets(transitions);
DataSet result = sut.compute(transitions);
// Assert
INDArray evaluatedQValues = result.getLabels();

View File

@ -2,10 +2,12 @@ package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException;
import java.io.OutputStream;
@ -66,21 +68,21 @@ public class MockDQN implements IDQN {
return new INDArray[0];
}
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(ITrainableNeuralNet from) {
throw new UnsupportedOperationException();
}
@Override
public IDQN clone() {
return null;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IDQN from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray label) {
return new Gradient[0];

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@ -88,6 +89,11 @@ public class PolicyTest {
throw new UnsupportedOperationException();
}
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(NN from) {
throw new UnsupportedOperationException();
@ -127,6 +133,16 @@ public class PolicyTest {
public void save(String filename) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public INDArray output(Observation observation) {
throw new UnsupportedOperationException();
}
@Override
public INDArray output(INDArray batch) {
throw new UnsupportedOperationException();
}
}
@Test

View File

@ -2,11 +2,13 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException;
import java.io.OutputStream;
@ -63,6 +65,16 @@ public class MockDQN implements IDQN {
return new INDArray[] { batch.mul(-1.0) };
}
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(ITrainableNeuralNet from) {
throw new UnsupportedOperationException();
}
@Override
public IDQN clone() {
MockDQN clone = new MockDQN();
@ -71,16 +83,6 @@ public class MockDQN implements IDQN {
return clone;
}
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IDQN from) {
}
@Override
public Gradient[] gradient(INDArray input, INDArray label) {
gradientParams.add(new Pair<INDArray, INDArray>(input, label));

View File

@ -2,8 +2,11 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
@ -39,15 +42,20 @@ public class MockNeuralNet implements NeuralNet {
}
@Override
public NeuralNet clone() {
return this;
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(NeuralNet from) {
public void copy(ITrainableNeuralNet from) {
++copyCallCount;
}
@Override
public NeuralNet clone() {
return this;
}
@Override
public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0];
@ -77,4 +85,14 @@ public class MockNeuralNet implements NeuralNet {
public void save(String filename) throws IOException {
}
@Override
public INDArray output(Observation observation) {
throw new UnsupportedOperationException();
}
@Override
public INDArray output(INDArray batch) {
throw new UnsupportedOperationException();
}
}