Merge pull request #9027 from KonduitAI/master

Development updates [WIP]
master
Alex Black 2020-06-29 19:55:29 +10:00 committed by GitHub
commit b3e3456b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 349 additions and 107 deletions

View File

@ -258,7 +258,7 @@ public class PythonTypes {
return ret; return ret;
}else if (javaObject instanceof byte[]){ }else if (javaObject instanceof byte[]){
byte[] arr = (byte[]) javaObject; byte[] arr = (byte[]) javaObject;
for (int x : arr) ret.add(x); for (int x : arr) ret.add(x & 0xff);
return ret; return ret;
} else if (javaObject instanceof long[]) { } else if (javaObject instanceof long[]) {
long[] arr = (long[]) javaObject; long[] arr = (long[]) javaObject;

View File

@ -81,16 +81,31 @@ public class PythonPrimitiveTypesTest {
} }
@Test @Test
public void testBytes() { public void testBytes() {
byte[] bytes = new byte[256];
for (int i = 0; i < 256; i++) {
bytes[i] = (byte) i;
}
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES));
String code = "b2=b1";
PythonExecutioner.exec(code, inputs, outputs);
Assert.assertArrayEquals(bytes, (byte[]) outputs.get(0).getValue());
}
@Test
public void testBytes2() {
byte[] bytes = new byte[]{97, 98, 99}; byte[] bytes = new byte[]{97, 98, 99};
List<PythonVariable> inputs = new ArrayList<>(); List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes)); inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes));
List<PythonVariable> outputs = new ArrayList<>(); List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES)); outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES));
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'"; String code = "s1 = ''.join(chr(c) for c in b1)\nb2=b'def'";
PythonExecutioner.exec(code, inputs, outputs); PythonExecutioner.exec(code, inputs, outputs);
Assert.assertEquals("abc", outputs.get(0).getValue()); Assert.assertEquals("abc", outputs.get(0).getValue());
Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[])outputs.get(1).getValue()); Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
} }
} }

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.rl4j.agent.update; package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.INeuralNetUpdater;
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.NeuralNetUpdater;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
@ -29,30 +31,26 @@ import java.util.List;
// and network update to sub components. // and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> { public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
private final IDQN qNetwork;
private final IDQN targetQNetwork; private final IDQN targetQNetwork;
private final int targetUpdateFrequency; private final INeuralNetUpdater updater;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm; private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@Getter @Getter
private int updateCount = 0; private int updateCount = 0;
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) { public DQNNeuralNetUpdateRule(IDQN<IDQN> qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
this.qNetwork = qNetwork;
this.targetQNetwork = qNetwork.clone(); this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp) ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp); : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
updater = new NeuralNetUpdater(qNetwork, targetQNetwork, targetUpdateFrequency);
} }
@Override @Override
public void update(List<Transition<Integer>> trainingBatch) { public void update(List<Transition<Integer>> trainingBatch) {
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); DataSet targets = tdTargetAlgorithm.compute(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels()); updater.update(targets);
if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork.copy(qNetwork);
}
} }
} }

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* The role of INeuralNetUpdater implementations is to update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}
* from a {@link DataSet}.<p />
*/
public interface INeuralNetUpdater {
/**
* Update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}.
* @param featuresLabels A Dataset that will be used to update the network.
*/
void update(DataSet featuresLabels);
}

View File

@ -0,0 +1,62 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* A {@link INeuralNetUpdater} that updates a neural network and sync a target network at defined intervals
*/
public class NeuralNetUpdater implements INeuralNetUpdater {
private final ITrainableNeuralNet current;
private final ITrainableNeuralNet target;
private int updateCount = 0;
private final int targetUpdateFrequency;
/**
* @param current The current {@link ITrainableNeuralNet network}
* @param target The target {@link ITrainableNeuralNet network}
* @param targetUpdateFrequency Will synchronize the target network at every <i>targetUpdateFrequency</i> updates
*/
public NeuralNetUpdater(ITrainableNeuralNet current,
ITrainableNeuralNet target,
int targetUpdateFrequency) {
this.current = current;
this.target = target;
this.targetUpdateFrequency = targetUpdateFrequency;
}
/**
* Update the current network
* @param featuresLabels A Dataset that will be used to update the network.
*/
@Override
public void update(DataSet featuresLabels) {
current.fit(featuresLabels);
syncTargetNetwork();
}
private void syncTargetNetwork() {
if(++updateCount % targetUpdateFrequency == 0) {
target.copy(current);
}
}
}

View File

@ -23,16 +23,19 @@ import lombok.Setter;
import lombok.Value; import lombok.Value;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.*; import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration; import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.policy.IPolicy;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.IDataManager;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;

View File

@ -18,11 +18,9 @@
package org.deeplearning4j.rl4j.learning.async; package org.deeplearning4j.rl4j.learning.async;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import java.util.concurrent.atomic.AtomicInteger; public interface IAsyncGlobal<NN extends ITrainableNeuralNet> {
public interface IAsyncGlobal<NN extends NeuralNet> {
boolean isTrainingComplete(); boolean isTrainingComplete();

View File

@ -80,7 +80,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal); protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
@Override @Override
public DataSet computeTDTargets(List<Transition<Integer>> transitions) { public DataSet compute(List<Transition<Integer>> transitions) {
int size = transitions.size(); int size = transitions.size();

View File

@ -34,5 +34,5 @@ public interface ITDTargetAlgorithm<A> {
* @param transitions The transitions from the experience replay * @param transitions The transitions from the experience replay
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions * @return A DataSet where every element is the observation and the estimated Q-Values for all actions
*/ */
DataSet computeTDTargets(List<Transition<A>> transitions); DataSet compute(List<Transition<A>> transitions);
} }

View File

@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.network;
import org.nd4j.linalg.dataset.api.DataSet;
/**
* An interface defining the <i>trainable</i> aspect of a {@link NeuralNet}.
*/
public interface ITrainableNeuralNet<NET_TYPE extends ITrainableNeuralNet> {
/**
* Train the neural net using the supplied <i>feature-labels</i>
* @param featuresLabels The feature-labels
*/
void fit(DataSet featuresLabels);
/**
* Changes this instance to be a copy of the <i>from</i> network.
* @param from The network that will be the source of the copy.
*/
void copy(NET_TYPE from);
/**
* Creates a clone of the network instance.
*/
NET_TYPE clone();
}

View File

@ -29,7 +29,7 @@ import java.io.OutputStream;
* Factorisation between ActorCritic and DQN neural net. * Factorisation between ActorCritic and DQN neural net.
* Useful for AsyncLearning and Thread code. * Useful for AsyncLearning and Thread code.
*/ */
public interface NeuralNet<NN extends NeuralNet> { public interface NeuralNet<NN extends NeuralNet> extends IOutputNeuralNet, ITrainableNeuralNet<NN> {
/** /**
* Returns the underlying MultiLayerNetwork or ComputationGraph objects. * Returns the underlying MultiLayerNetwork or ComputationGraph objects.
@ -52,18 +52,6 @@ public interface NeuralNet<NN extends NeuralNet> {
*/ */
INDArray[] outputAll(INDArray batch); INDArray[] outputAll(INDArray batch);
/**
* clone the Neural Net with the same paramaeters
* @return the cloned neural net
*/
NN clone();
/**
* copy the parameters from a neural net
* @param from where to copy parameters
*/
void copy(NN from);
/** /**
* Calculate the gradients from input and label (target) of all outputs * Calculate the gradients from input and label (target) of all outputs
* @param input input batch * @param input input batch

View File

@ -18,6 +18,7 @@
package org.deeplearning4j.rl4j.network.ac; package org.deeplearning4j.rl4j.network.ac;
import lombok.Getter; import lombok.Getter;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
@ -25,8 +26,10 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@ -80,6 +83,11 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
return nn; return nn;
} }
@Override
public void fit(DataSet featuresLabels) {
fit(featuresLabels.getFeatures(), new INDArray[] { featuresLabels.getLabels() });
}
public void copy(ActorCriticCompGraph from) { public void copy(ActorCriticCompGraph from) {
cg.setParams(from.cg.params()); cg.setParams(from.cg.params());
} }
@ -137,5 +145,19 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
public void save(String pathValue, String pathPolicy) throws IOException { public void save(String pathValue, String pathPolicy) throws IOException {
throw new UnsupportedOperationException("Call save(path)"); throw new UnsupportedOperationException("Call save(path)");
} }
@Override
public INDArray output(Observation observation) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
@Override
public INDArray output(INDArray batch) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
} }

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.network.ac; package org.deeplearning4j.rl4j.network.ac;
import lombok.Getter; import lombok.Getter;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
@ -24,8 +25,10 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@ -86,6 +89,13 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
return nn; return nn;
} }
@Override
public void fit(DataSet featuresLabels) {
// TODO: signature of fit() will change from DataSet to a class that has named labels to support network like
// this one (labels for the value-network and another labels for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
public void copy(NN from) { public void copy(NN from) {
valueNet.setParams(from.valueNet.params()); valueNet.setParams(from.valueNet.params());
policyNet.setParams(from.policyNet.params()); policyNet.setParams(from.policyNet.params());
@ -164,6 +174,20 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
ModelSerializer.writeModel(valueNet, pathValue, true); ModelSerializer.writeModel(valueNet, pathValue, true);
ModelSerializer.writeModel(policyNet, pathPolicy, true); ModelSerializer.writeModel(policyNet, pathPolicy, true);
} }
@Override
public INDArray output(Observation observation) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
@Override
public INDArray output(INDArray batch) {
// TODO: signature of output() will change to return a class that has named outputs to support network like
// this one (output from the value-network and another output for the policy-network
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
}
} }

View File

@ -16,7 +16,6 @@
package org.deeplearning4j.rl4j.network.ac; package org.deeplearning4j.rl4j.network.ac;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -33,27 +32,11 @@ import java.io.OutputStream;
*/ */
public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> { public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> {
boolean isRecurrent();
void reset();
void fit(INDArray input, INDArray[] labels);
//FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) ! //FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) !
INDArray[] outputAll(INDArray batch); INDArray[] outputAll(INDArray batch);
NN clone();
void copy(NN from);
Gradient[] gradient(INDArray input, INDArray[] labels);
void applyGradient(Gradient[] gradient, int batchSize);
void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException; void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException;
void save(String pathValue, String pathPolicy) throws IOException; void save(String pathValue, String pathPolicy) throws IOException;
double getLatestScore();
} }

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@ -33,7 +34,7 @@ import java.util.Collection;
/** /**
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16. * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
*/ */
public class DQN<NN extends DQN> implements IDQN<NN> { public class DQN implements IDQN<DQN> {
final protected MultiLayerNetwork mln; final protected MultiLayerNetwork mln;
@ -79,16 +80,23 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
return new INDArray[] {output(batch)}; return new INDArray[] {output(batch)};
} }
public NN clone() { @Override
NN nn = (NN)new DQN(mln.clone()); public void fit(DataSet featuresLabels) {
nn.mln.setListeners(mln.getListeners()); fit(featuresLabels.getFeatures(), featuresLabels.getLabels());
return nn;
} }
public void copy(NN from) { @Override
public void copy(DQN from) {
mln.setParams(from.mln.params()); mln.setParams(from.mln.params());
} }
@Override
public DQN clone() {
DQN nn = new DQN(mln.clone());
nn.mln.setListeners(mln.getListeners());
return nn;
}
public Gradient[] gradient(INDArray input, INDArray labels) { public Gradient[] gradient(INDArray input, INDArray labels) {
mln.setInput(input); mln.setInput(input);
mln.setLabels(labels); mln.setLabels(labels);

View File

@ -17,9 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn; package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
@ -28,27 +26,11 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* This neural net quantify the value of each action given a state * This neural net quantify the value of each action given a state
* *
*/ */
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet { public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
boolean isRecurrent();
void reset();
void fit(INDArray input, INDArray labels); void fit(INDArray input, INDArray labels);
void fit(INDArray input, INDArray[] labels);
INDArray[] outputAll(INDArray batch);
NN clone();
void copy(NN from);
Gradient[] gradient(INDArray input, INDArray label); Gradient[] gradient(INDArray input, INDArray label);
Gradient[] gradient(INDArray input, INDArray[] label); Gradient[] gradient(INDArray input, INDArray[] label);
void applyGradient(Gradient[] gradient, int batchSize);
double getLatestScore();
} }

View File

@ -0,0 +1,51 @@
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.dataset.api.DataSet;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class NeuralNetUpdaterTest {
@Mock
ITrainableNeuralNet currentMock;
@Mock
ITrainableNeuralNet targetMock;
@Test
public void when_callingUpdate_expect_currentUpdatedAndtargetNotChanged() {
// Arrange
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, Integer.MAX_VALUE);
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
// Act
sut.update(featureLabels);
// Assert
verify(currentMock, times(1)).fit(featureLabels);
verify(targetMock, never()).fit(any());
}
@Test
public void when_callingUpdate_expect_targetUpdatedFromCurrentAtFrequency() {
// Arrange
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, 3);
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
// Act
sut.update(featureLabels);
sut.update(featureLabels);
sut.update(featureLabels);
// Assert
verify(currentMock, never()).copy(any());
verify(targetMock, times(1)).copy(currentMock);
}
}

View File

@ -50,7 +50,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();
@ -74,7 +74,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();
@ -102,7 +102,7 @@ public class DoubleDQNTest {
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();

View File

@ -50,7 +50,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();
@ -72,7 +72,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();
@ -98,7 +98,7 @@ public class StandardDQNTest {
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.compute(transitions);
// Assert // Assert
INDArray evaluatedQValues = result.getLabels(); INDArray evaluatedQValues = result.getLabels();

View File

@ -2,10 +2,12 @@ package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@ -66,21 +68,21 @@ public class MockDQN implements IDQN {
return new INDArray[0]; return new INDArray[0];
} }
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(ITrainableNeuralNet from) {
throw new UnsupportedOperationException();
}
@Override @Override
public IDQN clone() { public IDQN clone() {
return null; return null;
} }
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IDQN from) {
}
@Override @Override
public Gradient[] gradient(INDArray input, INDArray label) { public Gradient[] gradient(INDArray input, INDArray label) {
return new Gradient[0]; return new Gradient[0];

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
@ -88,6 +89,11 @@ public class PolicyTest {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override @Override
public void copy(NN from) { public void copy(NN from) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
@ -127,6 +133,16 @@ public class PolicyTest {
public void save(String filename) throws IOException { public void save(String filename) throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public INDArray output(Observation observation) {
throw new UnsupportedOperationException();
}
@Override
public INDArray output(INDArray batch) {
throw new UnsupportedOperationException();
}
} }
@Test @Test

View File

@ -2,11 +2,13 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.dataset.api.DataSet;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@ -63,6 +65,16 @@ public class MockDQN implements IDQN {
return new INDArray[] { batch.mul(-1.0) }; return new INDArray[] { batch.mul(-1.0) };
} }
@Override
public void fit(DataSet featuresLabels) {
throw new UnsupportedOperationException();
}
@Override
public void copy(ITrainableNeuralNet from) {
throw new UnsupportedOperationException();
}
@Override @Override
public IDQN clone() { public IDQN clone() {
MockDQN clone = new MockDQN(); MockDQN clone = new MockDQN();
@ -71,16 +83,6 @@ public class MockDQN implements IDQN {
return clone; return clone;
} }
@Override
public void copy(NeuralNet from) {
}
@Override
public void copy(IDQN from) {
}
@Override @Override
public Gradient[] gradient(INDArray input, INDArray label) { public Gradient[] gradient(INDArray input, INDArray label) {
gradientParams.add(new Pair<INDArray, INDArray>(input, label)); gradientParams.add(new Pair<INDArray, INDArray>(input, label));

View File

@ -2,8 +2,11 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.nn.api.NeuralNetwork; import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException; import java.io.IOException;
@ -39,15 +42,20 @@ public class MockNeuralNet implements NeuralNet {
} }
@Override @Override
public NeuralNet clone() { public void fit(DataSet featuresLabels) {
return this; throw new UnsupportedOperationException();
} }
@Override @Override
public void copy(NeuralNet from) { public void copy(ITrainableNeuralNet from) {
++copyCallCount; ++copyCallCount;
} }
@Override
public NeuralNet clone() {
return this;
}
@Override @Override
public Gradient[] gradient(INDArray input, INDArray[] labels) { public Gradient[] gradient(INDArray input, INDArray[] labels) {
return new Gradient[0]; return new Gradient[0];
@ -77,4 +85,14 @@ public class MockNeuralNet implements NeuralNet {
public void save(String filename) throws IOException { public void save(String filename) throws IOException {
} }
@Override
public INDArray output(Observation observation) {
throw new UnsupportedOperationException();
}
@Override
public INDArray output(INDArray batch) {
throw new UnsupportedOperationException();
}
} }