RL4J - Added a unit test to help refac QLearningDiscrete.trainStep() (#8065)
* Added a unit test to help refac QLearningDiscrete.trainStep() Signed-off-by: unknown <aboulang2002@yahoo.com> * Changed expReplay setter to package private Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
b2145ca780
commit
b083c22de5
|
@ -45,8 +45,12 @@ import java.util.List;
|
|||
public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
|
||||
extends SyncLearning<O, A, AS, IDQN> {
|
||||
|
||||
// FIXME Changed for refac
|
||||
// @Getter
|
||||
// final private IExpReplay<A> expReplay;
|
||||
@Getter
|
||||
final private IExpReplay<A> expReplay;
|
||||
@Setter(AccessLevel.PACKAGE)
|
||||
private IExpReplay<A> expReplay;
|
||||
|
||||
public QLearning(QLConfiguration conf) {
|
||||
super(conf);
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
package org.deeplearning4j.rl4j.learning.async;
|
||||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.policy.Policy;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.deeplearning4j.rl4j.support.MockDataManager;
|
||||
import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.support.MockMDP;
|
||||
import org.deeplearning4j.rl4j.support.MockObservationSpace;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -93,7 +92,7 @@ public class AsyncThreadTest {
|
|||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
assertEquals(79.0, entry.getReward(), 0.0);
|
||||
}
|
||||
|
||||
assertEquals(10, dataManager.isSaveDataCallCount);
|
||||
|
@ -128,7 +127,7 @@ public class AsyncThreadTest {
|
|||
IDataManager.StatEntry entry = dataManager.statEntries.get(i);
|
||||
assertEquals(i + 1, entry.getStepCounter());
|
||||
assertEquals(i, entry.getEpochCounter());
|
||||
assertEquals(1.0, entry.getReward(), 0.0);
|
||||
assertEquals(79.0, entry.getReward(), 0.0);
|
||||
}
|
||||
|
||||
assertEquals(1, dataManager.isSaveDataCallCount);
|
||||
|
@ -308,91 +307,6 @@ public class AsyncThreadTest {
|
|||
}
|
||||
}
|
||||
|
||||
public static class MockEncodable implements Encodable {
|
||||
|
||||
private final int value;
|
||||
|
||||
public MockEncodable(int value) {
|
||||
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] toArray() {
|
||||
return new double[] { value };
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockObservationSpace implements ObservationSpace {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] getShape() {
|
||||
return new int[] { 1 };
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getLow() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getHigh() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||
|
||||
private final DiscreteSpace actionSpace;
|
||||
private int currentObsValue = 0;
|
||||
private final ObservationSpace observationSpace;
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace) {
|
||||
actionSpace = new DiscreteSpace(5);
|
||||
this.observationSpace = observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObservationSpace getObservationSpace() {
|
||||
return observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DiscreteSpace getActionSpace() {
|
||||
return actionSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockEncodable reset() {
|
||||
return new MockEncodable(++currentObsValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public StepReply<MockEncodable> step(Integer obs) {
|
||||
return new StepReply<MockEncodable>(new MockEncodable(obs), (double)obs, isDone(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP newInstance() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAsyncConfiguration implements AsyncConfiguration {
|
||||
|
||||
private final int nStep;
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.IDataManager;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class QLearningDiscreteTest {
|
||||
@Test
|
||||
public void refac_QLearningDiscrete_trainStep() {
|
||||
// Arrange
|
||||
MockObservationSpace observationSpace = new MockObservationSpace();
|
||||
MockMDP mdp = new MockMDP(observationSpace);
|
||||
MockDQN dqn = new MockDQN();
|
||||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockDataManager dataManager = new MockDataManager(false);
|
||||
TestQLearningDiscrete sut = new TestQLearningDiscrete(mdp, dqn, conf, dataManager, 10);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
MockExpReplay expReplay = new MockExpReplay();
|
||||
sut.setExpReplay(expReplay);
|
||||
MockEncodable obs = new MockEncodable(1);
|
||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||
|
||||
// Act
|
||||
sut.initMdp();
|
||||
for(int step = 0; step < 16; ++step) {
|
||||
results.add(sut.trainStep(obs));
|
||||
sut.incrementStep();
|
||||
}
|
||||
|
||||
// Assert
|
||||
// HistoryProcessor calls
|
||||
assertEquals(24, hp.recordCallCount);
|
||||
assertEquals(13, hp.addCallCount);
|
||||
assertEquals(0, hp.startMonitorCallCount);
|
||||
assertEquals(0, hp.stopMonitorCallCount);
|
||||
|
||||
// DQN calls
|
||||
assertEquals(1, dqn.fitParams.size());
|
||||
assertEquals(123.0, dqn.fitParams.get(0).getFirst().getDouble(0), 0.001);
|
||||
assertEquals(234.0, dqn.fitParams.get(0).getSecond().getDouble(0), 0.001);
|
||||
assertEquals(14, dqn.outputParams.size());
|
||||
double[][] expectedDQNOutput = new double[][] {
|
||||
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
|
||||
};
|
||||
for(int i = 0; i < expectedDQNOutput.length; ++i) {
|
||||
INDArray outputParam = dqn.outputParams.get(i);
|
||||
|
||||
assertEquals(5, outputParam.shape()[0]);
|
||||
assertEquals(1, outputParam.shape()[1]);
|
||||
|
||||
double[] expectedRow = expectedDQNOutput[i];
|
||||
for(int j = 0; j < expectedRow.length; ++j) {
|
||||
assertEquals(expectedRow[j] / 255.0, outputParam.getDouble(j), 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
// MDP calls
|
||||
assertArrayEquals(new Integer[] { 0, 0, 0, 0, 0, 0, 0, 0, 0 ,0, 4, 4, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 }, mdp.actions.toArray());
|
||||
|
||||
// ExpReplay calls
|
||||
double[] expectedTrRewards = new double[] { 9.0, 21.0, 25.0, 29.0, 33.0, 37.0, 41.0, 45.0 };
|
||||
int[] expectedTrActions = new int[] { 0, 4, 3, 4, 4, 4, 4, 4 };
|
||||
double[] expectedTrNextObservation = new double[] { 0, 0, 0, 1.0, 9.0, 11.0, 13.0, 15.0 };
|
||||
double[][] expectedTrObservations = new double[][] {
|
||||
new double[] { 0.0, 0.0, 0.0, 0.0, 1.0 },
|
||||
new double[] { 0.0, 0.0, 0.0, 1.0, 9.0 },
|
||||
new double[] { 0.0, 0.0, 1.0, 9.0, 11.0 },
|
||||
new double[] { 0.0, 1.0, 9.0, 11.0, 13.0 },
|
||||
new double[] { 1.0, 9.0, 11.0, 13.0, 15.0 },
|
||||
new double[] { 9.0, 11.0, 13.0, 15.0, 17.0 },
|
||||
new double[] { 11.0, 13.0, 15.0, 17.0, 19.0 },
|
||||
new double[] { 13.0, 15.0, 17.0, 19.0, 21.0 },
|
||||
};
|
||||
for(int i = 0; i < expectedTrRewards.length; ++i) {
|
||||
Transition tr = expReplay.transitions.get(i);
|
||||
assertEquals(expectedTrRewards[i], tr.getReward(), 0.0001);
|
||||
assertEquals(expectedTrActions[i], tr.getAction());
|
||||
assertEquals(expectedTrNextObservation[i], tr.getNextObservation().getDouble(0), 0.0001);
|
||||
for(int j = 0; j < expectedTrObservations[i].length; ++j) {
|
||||
assertEquals(expectedTrObservations[i][j], tr.getObservation()[j].getDouble(0), 0.0001);
|
||||
}
|
||||
}
|
||||
|
||||
// trainStep results
|
||||
assertEquals(16, results.size());
|
||||
double[] expectedMaxQ = new double[] { 1.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0 };
|
||||
double[] expectedRewards = new double[] { 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0, 23.0 };
|
||||
for(int i=0; i < 16; ++i) {
|
||||
QLearning.QLStepReturn<MockEncodable> result = results.get(i);
|
||||
if(i % 2 == 0) {
|
||||
assertEquals(expectedMaxQ[i/2] / 255.0, result.getMaxQ(), 0.001);
|
||||
assertEquals(expectedRewards[i/2], result.getStepReply().getReward(), 0.001);
|
||||
}
|
||||
else {
|
||||
assertTrue(result.getMaxQ().isNaN());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static class TestQLearningDiscrete extends QLearningDiscrete<MockEncodable> {
|
||||
public TestQLearningDiscrete(MDP<MockEncodable, Integer, DiscreteSpace> mdp,IDQN dqn,
|
||||
QLConfiguration conf, IDataManager dataManager, int epsilonNbStep) {
|
||||
super(mdp, dqn, conf, dataManager, epsilonNbStep);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
|
||||
return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 }));
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.nn.api.NeuralNetwork;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockDQN implements IDQN {
|
||||
|
||||
public final List<INDArray> outputParams = new ArrayList<>();
|
||||
public final List<Pair<INDArray, INDArray>> fitParams = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public NeuralNetwork[] getNeuralNetworks() {
|
||||
return new NeuralNetwork[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRecurrent() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray labels) {
|
||||
fitParams.add(new Pair<>(input, labels));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fit(INDArray input, INDArray[] labels) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray output(INDArray batch){
|
||||
outputParams.add(batch);
|
||||
return batch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] outputAll(INDArray batch) {
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public IDQN clone() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(NeuralNet from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void copy(IDQN from) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray label) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public Gradient[] gradient(INDArray input, INDArray[] label) {
|
||||
return new Gradient[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void applyGradient(Gradient[] gradient, int batchSize) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getLatestScore() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(OutputStream os) throws IOException {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(String filename) throws IOException {
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
|
||||
public class MockEncodable implements Encodable {
|
||||
|
||||
private final int value;
|
||||
|
||||
public MockEncodable(int value) {
|
||||
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double[] toArray() {
|
||||
return new double[] { value };
|
||||
}
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
|
||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockExpReplay implements IExpReplay<Integer> {
|
||||
|
||||
public List<Transition<Integer>> transitions = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public ArrayList<Transition<Integer>> getBatch() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void store(Transition<Integer> transition) {
|
||||
transitions.add(transition);
|
||||
}
|
||||
}
|
|
@ -1,15 +1,24 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
public class MockHistoryProcessor implements IHistoryProcessor {
|
||||
|
||||
public int recordCallCount = 0;
|
||||
public int addCallCount = 0;
|
||||
public int startMonitorCallCount = 0;
|
||||
public int stopMonitorCallCount = 0;
|
||||
|
||||
private final Configuration config;
|
||||
private final CircularFifoQueue<INDArray> history;
|
||||
|
||||
public MockHistoryProcessor(Configuration config) {
|
||||
|
||||
this.config = config;
|
||||
history = new CircularFifoQueue<>(config.getHistoryLength());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -19,27 +28,32 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
|||
|
||||
@Override
|
||||
public INDArray[] getHistory() {
|
||||
return new INDArray[0];
|
||||
INDArray[] array = new INDArray[getConf().getHistoryLength()];
|
||||
for (int i = 0; i < config.getHistoryLength(); i++) {
|
||||
array[i] = history.get(i).castTo(Nd4j.dataType());
|
||||
}
|
||||
return array;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void record(INDArray image) {
|
||||
|
||||
++recordCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(INDArray image) {
|
||||
|
||||
++addCallCount;
|
||||
history.add(image);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void startMonitor(String filename, int[] shape) {
|
||||
|
||||
++startMonitorCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stopMonitor() {
|
||||
|
||||
++stopMonitorCallCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -49,6 +63,6 @@ public class MockHistoryProcessor implements IHistoryProcessor {
|
|||
|
||||
@Override
|
||||
public double getScale() {
|
||||
return 0;
|
||||
return 255.0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
||||
|
||||
private final DiscreteSpace actionSpace;
|
||||
private int currentObsValue = 0;
|
||||
private final ObservationSpace observationSpace;
|
||||
|
||||
public final List<Integer> actions = new ArrayList<>();
|
||||
|
||||
public MockMDP(ObservationSpace observationSpace) {
|
||||
actionSpace = new DiscreteSpace(5);
|
||||
this.observationSpace = observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObservationSpace getObservationSpace() {
|
||||
return observationSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DiscreteSpace getActionSpace() {
|
||||
return actionSpace;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockEncodable reset() {
|
||||
currentObsValue = 0;
|
||||
return new MockEncodable(currentObsValue++);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public StepReply<MockEncodable> step(Integer action) {
|
||||
actions.add(action);
|
||||
return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDone() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MDP newInstance() {
|
||||
return null;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package org.deeplearning4j.rl4j.support;
|
||||
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public class MockObservationSpace implements ObservationSpace {
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int[] getShape() {
|
||||
return new int[] { 1 };
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getLow() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getHigh() {
|
||||
return null;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue