commit
b3e3456b89
|
@ -258,7 +258,7 @@ public class PythonTypes {
|
||||||
return ret;
|
return ret;
|
||||||
}else if (javaObject instanceof byte[]){
|
}else if (javaObject instanceof byte[]){
|
||||||
byte[] arr = (byte[]) javaObject;
|
byte[] arr = (byte[]) javaObject;
|
||||||
for (int x : arr) ret.add(x);
|
for (int x : arr) ret.add(x & 0xff);
|
||||||
return ret;
|
return ret;
|
||||||
} else if (javaObject instanceof long[]) {
|
} else if (javaObject instanceof long[]) {
|
||||||
long[] arr = (long[]) javaObject;
|
long[] arr = (long[]) javaObject;
|
||||||
|
|
|
@ -81,13 +81,28 @@ public class PythonPrimitiveTypesTest {
|
||||||
}
|
}
|
||||||
@Test
|
@Test
|
||||||
public void testBytes() {
|
public void testBytes() {
|
||||||
|
byte[] bytes = new byte[256];
|
||||||
|
for (int i = 0; i < 256; i++) {
|
||||||
|
bytes[i] = (byte) i;
|
||||||
|
}
|
||||||
|
List<PythonVariable> inputs = new ArrayList<>();
|
||||||
|
inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes));
|
||||||
|
List<PythonVariable> outputs = new ArrayList<>();
|
||||||
|
outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES));
|
||||||
|
String code = "b2=b1";
|
||||||
|
PythonExecutioner.exec(code, inputs, outputs);
|
||||||
|
Assert.assertArrayEquals(bytes, (byte[]) outputs.get(0).getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBytes2() {
|
||||||
byte[] bytes = new byte[]{97, 98, 99};
|
byte[] bytes = new byte[]{97, 98, 99};
|
||||||
List<PythonVariable> inputs = new ArrayList<>();
|
List<PythonVariable> inputs = new ArrayList<>();
|
||||||
inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes));
|
inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes));
|
||||||
List<PythonVariable> outputs = new ArrayList<>();
|
List<PythonVariable> outputs = new ArrayList<>();
|
||||||
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
|
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
|
||||||
outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES));
|
outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES));
|
||||||
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'";
|
String code = "s1 = ''.join(chr(c) for c in b1)\nb2=b'def'";
|
||||||
PythonExecutioner.exec(code, inputs, outputs);
|
PythonExecutioner.exec(code, inputs, outputs);
|
||||||
Assert.assertEquals("abc", outputs.get(0).getValue());
|
Assert.assertEquals("abc", outputs.get(0).getValue());
|
||||||
Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
|
Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
package org.deeplearning4j.rl4j.agent.update;
|
package org.deeplearning4j.rl4j.agent.update;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.INeuralNetUpdater;
|
||||||
|
import org.deeplearning4j.rl4j.agent.update.neuralnetupdater.NeuralNetUpdater;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||||
|
@ -29,30 +31,26 @@ import java.util.List;
|
||||||
// and network update to sub components.
|
// and network update to sub components.
|
||||||
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
|
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
|
||||||
|
|
||||||
private final IDQN qNetwork;
|
|
||||||
private final IDQN targetQNetwork;
|
private final IDQN targetQNetwork;
|
||||||
private final int targetUpdateFrequency;
|
private final INeuralNetUpdater updater;
|
||||||
|
|
||||||
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
private int updateCount = 0;
|
private int updateCount = 0;
|
||||||
|
|
||||||
public DQNNeuralNetUpdateRule(IDQN qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
public DQNNeuralNetUpdateRule(IDQN<IDQN> qNetwork, int targetUpdateFrequency, boolean isDoubleDQN, double gamma, double errorClamp) {
|
||||||
this.qNetwork = qNetwork;
|
|
||||||
this.targetQNetwork = qNetwork.clone();
|
this.targetQNetwork = qNetwork.clone();
|
||||||
this.targetUpdateFrequency = targetUpdateFrequency;
|
|
||||||
tdTargetAlgorithm = isDoubleDQN
|
tdTargetAlgorithm = isDoubleDQN
|
||||||
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
|
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
|
||||||
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
|
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
|
||||||
|
updater = new NeuralNetUpdater(qNetwork, targetQNetwork, targetUpdateFrequency);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void update(List<Transition<Integer>> trainingBatch) {
|
public void update(List<Transition<Integer>> trainingBatch) {
|
||||||
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
DataSet targets = tdTargetAlgorithm.compute(trainingBatch);
|
||||||
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
updater.update(targets);
|
||||||
if(++updateCount % targetUpdateFrequency == 0) {
|
|
||||||
targetQNetwork.copy(qNetwork);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The role of INeuralNetUpdater implementations is to update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}
|
||||||
|
* from a {@link DataSet}.<p />
|
||||||
|
*/
|
||||||
|
public interface INeuralNetUpdater {
|
||||||
|
/**
|
||||||
|
* Update a {@link org.deeplearning4j.rl4j.network.NeuralNet NeuralNet}.
|
||||||
|
* @param featuresLabels A Dataset that will be used to update the network.
|
||||||
|
*/
|
||||||
|
void update(DataSet featuresLabels);
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A {@link INeuralNetUpdater} that updates a neural network and sync a target network at defined intervals
|
||||||
|
*/
|
||||||
|
public class NeuralNetUpdater implements INeuralNetUpdater {
|
||||||
|
|
||||||
|
private final ITrainableNeuralNet current;
|
||||||
|
private final ITrainableNeuralNet target;
|
||||||
|
|
||||||
|
private int updateCount = 0;
|
||||||
|
private final int targetUpdateFrequency;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param current The current {@link ITrainableNeuralNet network}
|
||||||
|
* @param target The target {@link ITrainableNeuralNet network}
|
||||||
|
* @param targetUpdateFrequency Will synchronize the target network at every <i>targetUpdateFrequency</i> updates
|
||||||
|
*/
|
||||||
|
public NeuralNetUpdater(ITrainableNeuralNet current,
|
||||||
|
ITrainableNeuralNet target,
|
||||||
|
int targetUpdateFrequency) {
|
||||||
|
this.current = current;
|
||||||
|
this.target = target;
|
||||||
|
|
||||||
|
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the current network
|
||||||
|
* @param featuresLabels A Dataset that will be used to update the network.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void update(DataSet featuresLabels) {
|
||||||
|
current.fit(featuresLabels);
|
||||||
|
syncTargetNetwork();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void syncTargetNetwork() {
|
||||||
|
if(++updateCount % targetUpdateFrequency == 0) {
|
||||||
|
target.copy(current);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -23,16 +23,19 @@ import lombok.Setter;
|
||||||
import lombok.Value;
|
import lombok.Value;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.gym.StepReply;
|
import org.deeplearning4j.gym.StepReply;
|
||||||
import org.deeplearning4j.rl4j.learning.*;
|
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IEpochTrainer;
|
||||||
|
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||||
|
import org.deeplearning4j.rl4j.learning.Learning;
|
||||||
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
import org.deeplearning4j.rl4j.learning.configuration.IAsyncLearningConfiguration;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListener;
|
||||||
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList;
|
||||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.space.Encodable;
|
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.rl4j.policy.IPolicy;
|
import org.deeplearning4j.rl4j.policy.IPolicy;
|
||||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||||
|
import org.deeplearning4j.rl4j.space.Encodable;
|
||||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
|
@ -18,11 +18,9 @@
|
||||||
package org.deeplearning4j.rl4j.learning.async;
|
package org.deeplearning4j.rl4j.learning.async;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
public interface IAsyncGlobal<NN extends ITrainableNeuralNet> {
|
||||||
|
|
||||||
public interface IAsyncGlobal<NN extends NeuralNet> {
|
|
||||||
|
|
||||||
boolean isTrainingComplete();
|
boolean isTrainingComplete();
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
|
||||||
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
|
protected abstract double computeTarget(int batchIdx, double reward, boolean isTerminal);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataSet computeTDTargets(List<Transition<Integer>> transitions) {
|
public DataSet compute(List<Transition<Integer>> transitions) {
|
||||||
|
|
||||||
int size = transitions.size();
|
int size = transitions.size();
|
||||||
|
|
||||||
|
|
|
@ -34,5 +34,5 @@ public interface ITDTargetAlgorithm<A> {
|
||||||
* @param transitions The transitions from the experience replay
|
* @param transitions The transitions from the experience replay
|
||||||
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions
|
* @return A DataSet where every element is the observation and the estimated Q-Values for all actions
|
||||||
*/
|
*/
|
||||||
DataSet computeTDTargets(List<Transition<A>> transitions);
|
DataSet compute(List<Transition<A>> transitions);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.deeplearning4j.rl4j.network;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An interface defining the <i>trainable</i> aspect of a {@link NeuralNet}.
|
||||||
|
*/
|
||||||
|
public interface ITrainableNeuralNet<NET_TYPE extends ITrainableNeuralNet> {
|
||||||
|
/**
|
||||||
|
* Train the neural net using the supplied <i>feature-labels</i>
|
||||||
|
* @param featuresLabels The feature-labels
|
||||||
|
*/
|
||||||
|
void fit(DataSet featuresLabels);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Changes this instance to be a copy of the <i>from</i> network.
|
||||||
|
* @param from The network that will be the source of the copy.
|
||||||
|
*/
|
||||||
|
void copy(NET_TYPE from);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a clone of the network instance.
|
||||||
|
*/
|
||||||
|
NET_TYPE clone();
|
||||||
|
}
|
|
@ -29,7 +29,7 @@ import java.io.OutputStream;
|
||||||
* Factorisation between ActorCritic and DQN neural net.
|
* Factorisation between ActorCritic and DQN neural net.
|
||||||
* Useful for AsyncLearning and Thread code.
|
* Useful for AsyncLearning and Thread code.
|
||||||
*/
|
*/
|
||||||
public interface NeuralNet<NN extends NeuralNet> {
|
public interface NeuralNet<NN extends NeuralNet> extends IOutputNeuralNet, ITrainableNeuralNet<NN> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the underlying MultiLayerNetwork or ComputationGraph objects.
|
* Returns the underlying MultiLayerNetwork or ComputationGraph objects.
|
||||||
|
@ -52,18 +52,6 @@ public interface NeuralNet<NN extends NeuralNet> {
|
||||||
*/
|
*/
|
||||||
INDArray[] outputAll(INDArray batch);
|
INDArray[] outputAll(INDArray batch);
|
||||||
|
|
||||||
/**
|
|
||||||
* clone the Neural Net with the same paramaeters
|
|
||||||
* @return the cloned neural net
|
|
||||||
*/
|
|
||||||
NN clone();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* copy the parameters from a neural net
|
|
||||||
* @param from where to copy parameters
|
|
||||||
*/
|
|
||||||
void copy(NN from);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the gradients from input and label (target) of all outputs
|
* Calculate the gradients from input and label (target) of all outputs
|
||||||
* @param input input batch
|
* @param input input batch
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.deeplearning4j.rl4j.network.ac;
|
package org.deeplearning4j.rl4j.network.ac;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
@ -25,8 +26,10 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
|
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -80,6 +83,11 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
|
||||||
return nn;
|
return nn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(DataSet featuresLabels) {
|
||||||
|
fit(featuresLabels.getFeatures(), new INDArray[] { featuresLabels.getLabels() });
|
||||||
|
}
|
||||||
|
|
||||||
public void copy(ActorCriticCompGraph from) {
|
public void copy(ActorCriticCompGraph from) {
|
||||||
cg.setParams(from.cg.params());
|
cg.setParams(from.cg.params());
|
||||||
}
|
}
|
||||||
|
@ -137,5 +145,19 @@ public class ActorCriticCompGraph implements IActorCritic<ActorCriticCompGraph>
|
||||||
public void save(String pathValue, String pathPolicy) throws IOException {
|
public void save(String pathValue, String pathPolicy) throws IOException {
|
||||||
throw new UnsupportedOperationException("Call save(path)");
|
throw new UnsupportedOperationException("Call save(path)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||||
|
// this one (output from the value-network and another output for the policy-network
|
||||||
|
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(INDArray batch) {
|
||||||
|
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||||
|
// this one (output from the value-network and another output for the policy-network
|
||||||
|
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.deeplearning4j.rl4j.network.ac;
|
package org.deeplearning4j.rl4j.network.ac;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
@ -24,8 +25,10 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -86,6 +89,13 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
|
||||||
return nn;
|
return nn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(DataSet featuresLabels) {
|
||||||
|
// TODO: signature of fit() will change from DataSet to a class that has named labels to support network like
|
||||||
|
// this one (labels for the value-network and another labels for the policy-network
|
||||||
|
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||||
|
}
|
||||||
|
|
||||||
public void copy(NN from) {
|
public void copy(NN from) {
|
||||||
valueNet.setParams(from.valueNet.params());
|
valueNet.setParams(from.valueNet.params());
|
||||||
policyNet.setParams(from.policyNet.params());
|
policyNet.setParams(from.policyNet.params());
|
||||||
|
@ -164,6 +174,20 @@ public class ActorCriticSeparate<NN extends ActorCriticSeparate> implements IAct
|
||||||
ModelSerializer.writeModel(valueNet, pathValue, true);
|
ModelSerializer.writeModel(valueNet, pathValue, true);
|
||||||
ModelSerializer.writeModel(policyNet, pathPolicy, true);
|
ModelSerializer.writeModel(policyNet, pathPolicy, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||||
|
// this one (output from the value-network and another output for the policy-network
|
||||||
|
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(INDArray batch) {
|
||||||
|
// TODO: signature of output() will change to return a class that has named outputs to support network like
|
||||||
|
// this one (output from the value-network and another output for the policy-network
|
||||||
|
throw new NotImplementedException("Not implemented: will be done with AgentLearner async support");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
|
|
||||||
package org.deeplearning4j.rl4j.network.ac;
|
package org.deeplearning4j.rl4j.network.ac;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -33,27 +32,11 @@ import java.io.OutputStream;
|
||||||
*/
|
*/
|
||||||
public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> {
|
public interface IActorCritic<NN extends IActorCritic> extends NeuralNet<NN> {
|
||||||
|
|
||||||
boolean isRecurrent();
|
|
||||||
|
|
||||||
void reset();
|
|
||||||
|
|
||||||
void fit(INDArray input, INDArray[] labels);
|
|
||||||
|
|
||||||
//FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) !
|
//FIRST SHOULD BE VALUE AND SECOND IS SOFTMAX POLICY. DONT MESS THIS UP OR ELSE ASYNC THREAD IS BROKEN (maxQ) !
|
||||||
INDArray[] outputAll(INDArray batch);
|
INDArray[] outputAll(INDArray batch);
|
||||||
|
|
||||||
NN clone();
|
|
||||||
|
|
||||||
void copy(NN from);
|
|
||||||
|
|
||||||
Gradient[] gradient(INDArray input, INDArray[] labels);
|
|
||||||
|
|
||||||
void applyGradient(Gradient[] gradient, int batchSize);
|
|
||||||
|
|
||||||
void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException;
|
void save(OutputStream streamValue, OutputStream streamPolicy) throws IOException;
|
||||||
|
|
||||||
void save(String pathValue, String pathPolicy) throws IOException;
|
void save(String pathValue, String pathPolicy) throws IOException;
|
||||||
|
|
||||||
double getLatestScore();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
import org.deeplearning4j.util.ModelSerializer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -33,7 +34,7 @@ import java.util.Collection;
|
||||||
/**
|
/**
|
||||||
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
* @author rubenfiszel (ruben.fiszel@epfl.ch) 7/25/16.
|
||||||
*/
|
*/
|
||||||
public class DQN<NN extends DQN> implements IDQN<NN> {
|
public class DQN implements IDQN<DQN> {
|
||||||
|
|
||||||
final protected MultiLayerNetwork mln;
|
final protected MultiLayerNetwork mln;
|
||||||
|
|
||||||
|
@ -79,16 +80,23 @@ public class DQN<NN extends DQN> implements IDQN<NN> {
|
||||||
return new INDArray[] {output(batch)};
|
return new INDArray[] {output(batch)};
|
||||||
}
|
}
|
||||||
|
|
||||||
public NN clone() {
|
@Override
|
||||||
NN nn = (NN)new DQN(mln.clone());
|
public void fit(DataSet featuresLabels) {
|
||||||
nn.mln.setListeners(mln.getListeners());
|
fit(featuresLabels.getFeatures(), featuresLabels.getLabels());
|
||||||
return nn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void copy(NN from) {
|
@Override
|
||||||
|
public void copy(DQN from) {
|
||||||
mln.setParams(from.mln.params());
|
mln.setParams(from.mln.params());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DQN clone() {
|
||||||
|
DQN nn = new DQN(mln.clone());
|
||||||
|
nn.mln.setListeners(mln.getListeners());
|
||||||
|
return nn;
|
||||||
|
}
|
||||||
|
|
||||||
public Gradient[] gradient(INDArray input, INDArray labels) {
|
public Gradient[] gradient(INDArray input, INDArray labels) {
|
||||||
mln.setInput(input);
|
mln.setInput(input);
|
||||||
mln.setLabels(labels);
|
mln.setLabels(labels);
|
||||||
|
|
|
@ -17,9 +17,7 @@
|
||||||
package org.deeplearning4j.rl4j.network.dqn;
|
package org.deeplearning4j.rl4j.network.dqn;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
|
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -28,27 +26,11 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
* This neural net quantify the value of each action given a state
|
* This neural net quantify the value of each action given a state
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
|
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
|
||||||
|
|
||||||
boolean isRecurrent();
|
|
||||||
|
|
||||||
void reset();
|
|
||||||
|
|
||||||
void fit(INDArray input, INDArray labels);
|
void fit(INDArray input, INDArray labels);
|
||||||
|
|
||||||
void fit(INDArray input, INDArray[] labels);
|
|
||||||
|
|
||||||
INDArray[] outputAll(INDArray batch);
|
|
||||||
|
|
||||||
NN clone();
|
|
||||||
|
|
||||||
void copy(NN from);
|
|
||||||
|
|
||||||
Gradient[] gradient(INDArray input, INDArray label);
|
Gradient[] gradient(INDArray input, INDArray label);
|
||||||
|
|
||||||
Gradient[] gradient(INDArray input, INDArray[] label);
|
Gradient[] gradient(INDArray input, INDArray[] label);
|
||||||
|
|
||||||
void applyGradient(Gradient[] gradient, int batchSize);
|
|
||||||
|
|
||||||
double getLatestScore();
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
package org.deeplearning4j.rl4j.agent.update.neuralnetupdater;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.mockito.Mock;
|
||||||
|
import org.mockito.junit.MockitoJUnitRunner;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
|
public class NeuralNetUpdaterTest {
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
ITrainableNeuralNet currentMock;
|
||||||
|
|
||||||
|
@Mock
|
||||||
|
ITrainableNeuralNet targetMock;
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingUpdate_expect_currentUpdatedAndtargetNotChanged() {
|
||||||
|
// Arrange
|
||||||
|
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, Integer.MAX_VALUE);
|
||||||
|
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.update(featureLabels);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
verify(currentMock, times(1)).fit(featureLabels);
|
||||||
|
verify(targetMock, never()).fit(any());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_callingUpdate_expect_targetUpdatedFromCurrentAtFrequency() {
|
||||||
|
// Arrange
|
||||||
|
NeuralNetUpdater sut = new NeuralNetUpdater(currentMock, targetMock, 3);
|
||||||
|
DataSet featureLabels = new org.nd4j.linalg.dataset.DataSet();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.update(featureLabels);
|
||||||
|
sut.update(featureLabels);
|
||||||
|
sut.update(featureLabels);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
verify(currentMock, never()).copy(any());
|
||||||
|
verify(targetMock, times(1)).copy(currentMock);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -50,7 +50,7 @@ public class DoubleDQNTest {
|
||||||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
@ -74,7 +74,7 @@ public class DoubleDQNTest {
|
||||||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
@ -102,7 +102,7 @@ public class DoubleDQNTest {
|
||||||
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class StandardDQNTest {
|
||||||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
@ -72,7 +72,7 @@ public class StandardDQNTest {
|
||||||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
@ -98,7 +98,7 @@ public class StandardDQNTest {
|
||||||
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
|
||||||
|
|
||||||
// Act
|
// Act
|
||||||
DataSet result = sut.computeTDTargets(transitions);
|
DataSet result = sut.compute(transitions);
|
||||||
|
|
||||||
// Assert
|
// Assert
|
||||||
INDArray evaluatedQValues = result.getLabels();
|
INDArray evaluatedQValues = result.getLabels();
|
||||||
|
|
|
@ -2,10 +2,12 @@ package org.deeplearning4j.rl4j.learning.sync.support;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -66,21 +68,21 @@ public class MockDQN implements IDQN {
|
||||||
return new INDArray[0];
|
return new INDArray[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(DataSet featuresLabels) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(ITrainableNeuralNet from) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IDQN clone() {
|
public IDQN clone() {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(NeuralNet from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(IDQN from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Gradient[] gradient(INDArray input, INDArray label) {
|
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||||
return new Gradient[0];
|
return new Gradient[0];
|
||||||
|
|
|
@ -36,6 +36,7 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
@ -88,6 +89,11 @@ public class PolicyTest {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(DataSet featuresLabels) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void copy(NN from) {
|
public void copy(NN from) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
@ -127,6 +133,16 @@ public class PolicyTest {
|
||||||
public void save(String filename) throws IOException {
|
public void save(String filename) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(INDArray batch) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -2,11 +2,13 @@ package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||||
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
|
@ -63,6 +65,16 @@ public class MockDQN implements IDQN {
|
||||||
return new INDArray[] { batch.mul(-1.0) };
|
return new INDArray[] { batch.mul(-1.0) };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void fit(DataSet featuresLabels) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void copy(ITrainableNeuralNet from) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public IDQN clone() {
|
public IDQN clone() {
|
||||||
MockDQN clone = new MockDQN();
|
MockDQN clone = new MockDQN();
|
||||||
|
@ -71,16 +83,6 @@ public class MockDQN implements IDQN {
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(NeuralNet from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void copy(IDQN from) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Gradient[] gradient(INDArray input, INDArray label) {
|
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||||
gradientParams.add(new Pair<INDArray, INDArray>(input, label));
|
gradientParams.add(new Pair<INDArray, INDArray>(input, label));
|
||||||
|
|
|
@ -2,8 +2,11 @@ package org.deeplearning4j.rl4j.support;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||||
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
|
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
|
||||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||||
|
import org.deeplearning4j.rl4j.observation.Observation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -39,15 +42,20 @@ public class MockNeuralNet implements NeuralNet {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NeuralNet clone() {
|
public void fit(DataSet featuresLabels) {
|
||||||
return this;
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void copy(NeuralNet from) {
|
public void copy(ITrainableNeuralNet from) {
|
||||||
++copyCallCount;
|
++copyCallCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public NeuralNet clone() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
public Gradient[] gradient(INDArray input, INDArray[] labels) {
|
||||||
return new Gradient[0];
|
return new Gradient[0];
|
||||||
|
@ -77,4 +85,14 @@ public class MockNeuralNet implements NeuralNet {
|
||||||
public void save(String filename) throws IOException {
|
public void save(String filename) throws IOException {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(Observation observation) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray output(INDArray batch) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue