RL4J: Add TransformProcess, part 2 (#8766)

* Part 2 of TransformProcess

Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>

* Fix compile errors

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

* Revert unrelated changes

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>

Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
master
Alexandre Boulanger 2020-03-10 22:56:41 -04:00 committed by GitHub
parent 0faf83b1b6
commit 8b10f0b876
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 932 additions and 124 deletions

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,57 +17,43 @@
package org.deeplearning4j.rl4j.observation; package org.deeplearning4j.rl4j.observation;
import lombok.Getter; import lombok.Getter;
import lombok.Setter;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
/** /**
* Presently only a dummy container. Will contain observation channels when done. * Represent an observation from the environment
*
* @author Alexandre Boulanger
*/ */
public class Observation { public class Observation {
// TODO: Presently only a dummy container. Will contain observation channels when done.
private final DataSet data; /**
* A singleton representing a skipped observation
*/
public static Observation SkippedObservation = new Observation(null);
@Getter @Setter /**
private boolean skipped; * @return A INDArray containing the data of the observation
*/
@Getter
private final INDArray data;
public Observation(INDArray[] data) { public boolean isSkipped() {
this(data, false); return data == null;
} }
public Observation(INDArray[] data, boolean shouldReshape) {
INDArray features = Nd4j.concat(0, data);
if(shouldReshape) {
features = reshape(features);
}
this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
}
// FIXME: Remove -- only used in unit tests
public Observation(INDArray data) { public Observation(INDArray data) {
this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
}
private INDArray reshape(INDArray source) {
long[] shape = source.shape();
long[] nshape = new long[shape.length + 1];
nshape[0] = 1;
System.arraycopy(shape, 0, nshape, 1, shape.length);
return source.reshape(nshape);
}
private Observation(DataSet data) {
this.data = data; this.data = data;
} }
/**
* Creates a duplicate instance of the current observation
* @return
*/
public Observation dup() { public Observation dup() {
return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null)); if(data == null) {
return SkippedObservation;
} }
public INDArray getData() { return new Observation(data.dup());
return data.getFeatures();
} }
} }

View File

@ -0,0 +1,231 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform;
import org.apache.commons.lang3.NotImplementedException;
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.shade.guava.collect.Maps;
import org.datavec.api.transform.Operation;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
/**
* A TransformProcess will build an {@link Observation Observation} from the raw data coming from the environment.
* Three types of steps are available:
* 1) A {@link FilterOperation FilterOperation}: Used to determine if an observation should be skipped.
* 2) An {@link Operation Operation}: Applies a transform and/or conversion to an observation channel.
* 3) A {@link DataSetPreProcessor DataSetPreProcessor}: Applies a DataSetPreProcessor to the observation channel. The data in the channel must be a DataSet.
*
* Instances of the three types above can be called in any order. The only requirement is that when build() is called,
* all channels must be instances of INDArrays or DataSets
*
* NOTE: Presently, only single-channels observations are supported.
*
* @author Alexandre Boulanger
*/
public class TransformProcess {
private final List<Map.Entry<String, Object>> operations;
private final String[] channelNames;
private final HashSet<String> operationsChannelNames;
private TransformProcess(Builder builder, String... channelNames) {
operations = builder.operations;
this.channelNames = channelNames;
this.operationsChannelNames = builder.requiredChannelNames;
}
/**
* This method will call reset() of all steps implementing {@link ResettableOperation ResettableOperation} in the transform process.
*/
public void reset() {
for(Map.Entry<String, Object> entry : operations) {
if(entry.getValue() instanceof ResettableOperation) {
((ResettableOperation) entry.getValue()).reset();
}
}
}
/**
* Transforms the channel data into an Observation or a skipped observation depending on the specific steps in the transform process.
*
* @param channelsData A Map that maps the channel name to its data.
* @param currentObservationStep The observation's step number within the episode.
* @param isFinalObservation True if the observation is the last of the episode.
* @return An observation (may be a skipped observation)
*/
public Observation transform(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
// null or empty channelData
Preconditions.checkArgument(channelsData != null && channelsData.size() != 0, "Error: channelsData not supplied.");
// Check that all channels have data
for(Map.Entry<String, Object> channel : channelsData.entrySet()) {
Preconditions.checkNotNull(channel.getValue(), "Error: data of channel '%s' is null", channel.getKey());
}
// Check that all required channels are present
for(String channelName : operationsChannelNames) {
Preconditions.checkArgument(channelsData.containsKey(channelName), "The channelsData map does not contain the channel '%s'", channelName);
}
for(Map.Entry<String, Object> entry : operations) {
// Filter
if(entry.getValue() instanceof FilterOperation) {
FilterOperation filterOperation = (FilterOperation)entry.getValue();
if(filterOperation.isSkipped(channelsData, currentObservationStep, isFinalObservation)) {
return Observation.SkippedObservation;
}
}
// Transform
// null results are considered skipped observations
else if(entry.getValue() instanceof Operation) {
Operation transformOperation = (Operation)entry.getValue();
Object transformed = transformOperation.transform(channelsData.get(entry.getKey()));
if(transformed == null) {
return Observation.SkippedObservation;
}
channelsData.replace(entry.getKey(), transformed);
}
// PreProcess
else if(entry.getValue() instanceof DataSetPreProcessor) {
Object channelData = channelsData.get(entry.getKey());
DataSetPreProcessor dataSetPreProcessor = (DataSetPreProcessor)entry.getValue();
if(!(channelData instanceof DataSet)) {
throw new IllegalArgumentException("The channel data must be a DataSet to call preProcess");
}
dataSetPreProcessor.preProcess((DataSet)channelData);
}
else {
throw new IllegalArgumentException(String.format("Unknown operation: '%s'", entry.getValue().getClass().getName()));
}
}
// Check that all channels used to build the observation are instances of
// INDArray or DataSet
// TODO: Add support for an interface with a toINDArray() method
for(String channelName : channelNames) {
Object channelData = channelsData.get(channelName);
INDArray finalChannelData;
if(channelData instanceof DataSet) {
finalChannelData = ((DataSet)channelData).getFeatures();
}
else if(channelData instanceof INDArray) {
finalChannelData = (INDArray) channelData;
}
else {
throw new IllegalStateException("All channels used to build the observation must be instances of DataSet or INDArray");
}
// The dimension 0 of all INDArrays must be 1 (batch count)
channelsData.replace(channelName, INDArrayHelper.forceCorrectShape(finalChannelData));
}
// TODO: Add support to multi-channel observations
INDArray data = ((INDArray) channelsData.get(channelNames[0]));
return new Observation(data);
}
/**
* @return An instance of a builder
*/
public static Builder builder() {
return new Builder();
}
public static class Builder {
private final List<Map.Entry<String, Object>> operations = new ArrayList<Map.Entry<String, Object>>();
private final HashSet<String> requiredChannelNames = new HashSet<String>();
/**
* Add a filter to the transform process steps. Used to skip observations on certain conditions.
* See {@link FilterOperation FilterOperation}
* @param filterOperation An instance
*/
public Builder filter(FilterOperation filterOperation) {
Preconditions.checkNotNull(filterOperation, "The filterOperation must not be null");
operations.add((Map.Entry)Maps.immutableEntry(null, filterOperation));
return this;
}
/**
* Add a transform to the steps. The transform can change the data and / or change the type of the data
* (e.g. Byte[] to a ImageWritable)
*
* @param targetChannel The name of the channel to which the transform is applied.
* @param transformOperation An instance of {@link Operation Operation}
*/
public Builder transform(String targetChannel, Operation transformOperation) {
Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
Preconditions.checkNotNull(transformOperation, "The transformOperation must not be null");
requiredChannelNames.add(targetChannel);
operations.add((Map.Entry)Maps.immutableEntry(targetChannel, transformOperation));
return this;
}
/**
* Add a DataSetPreProcessor to the steps. The channel must be a DataSet instance at this step.
* @param targetChannel The name of the channel to which the pre processor is applied.
* @param dataSetPreProcessor
*/
public Builder preProcess(String targetChannel, DataSetPreProcessor dataSetPreProcessor) {
Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
Preconditions.checkNotNull(dataSetPreProcessor, "The dataSetPreProcessor must not be null");
requiredChannelNames.add(targetChannel);
operations.add((Map.Entry)Maps.immutableEntry(targetChannel, dataSetPreProcessor));
return this;
}
/**
* Builds the TransformProcess.
* @param channelNames A subset of channel names to be used to build the observation
* @return An instance of TransformProcess
*/
public TransformProcess build(String... channelNames) {
if(channelNames.length == 0) {
throw new IllegalArgumentException("At least one channel must be supplied.");
}
for(String channelName : channelNames) {
Preconditions.checkNotNull(channelName, "Error: got a null channel name");
requiredChannelNames.add(channelName);
}
// TODO: Remove when multi-channel observation is supported
if(channelNames.length != 1) {
throw new NotImplementedException("Multi-channel observations is not presently supported.");
}
return new TransformProcess(this, channelNames);
}
}
}

View File

@ -25,14 +25,14 @@ import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.opencv.global.opencv_core.CV_32FC; import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
public class EncodableToImageWriteableTransform implements Operation<Encodable, ImageWritable> { public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
private final int height; private final int height;
private final int width; private final int width;
private final int colorChannels; private final int colorChannels;
public EncodableToImageWriteableTransform(int height, int width, int colorChannels) { public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
this.height = height; this.height = height;
this.width = width; this.width = width;
this.colorChannels = colorChannels; this.colorChannels = colorChannels;

View File

@ -1,3 +1,18 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.legacy; package org.deeplearning4j.rl4j.observation.transform.legacy;
import org.datavec.api.transform.Operation; import org.datavec.api.transform.Operation;
@ -10,13 +25,13 @@ import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException; import java.io.IOException;
public class ImageWriteableToINDArrayTransform implements Operation<ImageWritable, INDArray> { public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
private final int height; private final int height;
private final int width; private final int width;
private final NativeImageLoader loader; private final NativeImageLoader loader;
public ImageWriteableToINDArrayTransform(int height, int width) { public ImageWritableToINDArrayTransform(int height, int width) {
this.height = height; this.height = height;
this.width = width; this.width = width;
this.loader = new NativeImageLoader(height, width); this.loader = new NativeImageLoader(height, width);

View File

@ -0,0 +1,44 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.rl4j.observation.transform.operation;
import org.datavec.api.transform.Operation;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
public class SimpleNormalizationTransform implements Operation<INDArray, INDArray> {
private final double offset;
private final double divisor;
public SimpleNormalizationTransform(double min, double max) {
Preconditions.checkArgument(min < max, "Min must be smaller than max.");
this.offset = min;
this.divisor = (max - min);
}
@Override
public INDArray transform(INDArray input) {
if(offset != 0.0) {
input.subi(offset);
}
input.divi(divisor);
return input;
}
}

View File

@ -3,37 +3,95 @@ package org.deeplearning4j.rl4j.util;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.CropImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.EpochStepCounter;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;
import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> { public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
@Getter @Getter
private final MDP<O, A, AS> wrappedMDP; private final MDP<O, A, AS> wrappedMDP;
@Getter @Getter
private final WrapperObservationSpace observationSpace; private final WrapperObservationSpace observationSpace;
private final int[] shape;
@Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC) @Setter
private TransformProcess transformProcess;
@Getter(AccessLevel.PRIVATE)
private IHistoryProcessor historyProcessor; private IHistoryProcessor historyProcessor;
private final EpochStepCounter epochStepCounter; private final EpochStepCounter epochStepCounter;
private int skipFrame = 1; private int skipFrame = 1;
private int requiredFrame = 0;
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
this.wrappedMDP = wrappedMDP; this.wrappedMDP = wrappedMDP;
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); this.shape = wrappedMDP.getObservationSpace().getShape();
this.observationSpace = new WrapperObservationSpace(shape);
this.historyProcessor = historyProcessor; this.historyProcessor = historyProcessor;
this.epochStepCounter = epochStepCounter; this.epochStepCounter = epochStepCounter;
setHistoryProcessor(historyProcessor);
}
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
this.historyProcessor = historyProcessor;
createTransformProcess();
}
private void createTransformProcess() {
IHistoryProcessor historyProcessor = getHistoryProcessor();
if(historyProcessor != null && shape.length == 3) {
int skipFrame = historyProcessor.getConf().getSkipFrame();
int finalHeight = historyProcessor.getConf().getCroppingHeight();
int finalWidth = historyProcessor.getConf().getCroppingWidth();
transformProcess = TransformProcess.builder()
.filter(new UniformSkippingFilter(skipFrame))
.transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2]))
.transform("data", new MultiImageTransform(
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
new ColorConversionTransform(COLOR_BGR2GRAY),
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth)
))
.transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth))
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
.transform("data", HistoryMergeTransform.builder()
.isFirstDimenstionBatch(true)
.build())
.build("data");
}
else {
transformProcess = TransformProcess.builder()
.transform("data", new EncodableToINDArrayTransform(shape))
.build("data");
}
} }
@Override @Override
@ -43,25 +101,17 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
@Override @Override
public Observation reset() { public Observation reset() {
INDArray rawObservation = getInput(wrappedMDP.reset()); transformProcess.reset();
IHistoryProcessor historyProcessor = getHistoryProcessor(); O rawResetResponse = wrappedMDP.reset();
if(historyProcessor != null) { record(rawResetResponse);
historyProcessor.record(rawObservation);
}
Observation observation = new Observation(new INDArray[] { rawObservation }, false);
if(historyProcessor != null) { if(historyProcessor != null) {
skipFrame = historyProcessor.getConf().getSkipFrame(); skipFrame = historyProcessor.getConf().getSkipFrame();
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
historyProcessor.add(rawObservation);
} }
observation.setSkipped(skipFrame != 0); Map<String, Object> channelsData = buildChannelsData(rawResetResponse);
return transformProcess.transform(channelsData, 0, false);
return observation;
} }
@Override @Override
@ -71,32 +121,32 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
StepReply<O> rawStepReply = wrappedMDP.step(a); StepReply<O> rawStepReply = wrappedMDP.step(a);
INDArray rawObservation = getInput(rawStepReply.getObservation()); INDArray rawObservation = getInput(rawStepReply.getObservation());
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
if(historyProcessor != null) { if(historyProcessor != null) {
historyProcessor.record(rawObservation); historyProcessor.record(rawObservation);
if (stepOfObservation % skipFrame == 0) {
historyProcessor.add(rawObservation);
}
} }
Observation observation; int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
if(historyProcessor != null && stepOfObservation >= requiredFrame) {
observation = new Observation(historyProcessor.getHistory(), true);
observation.getData().muli(1.0 / historyProcessor.getScale());
}
else {
observation = new Observation(new INDArray[] { rawObservation }, false);
}
if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
observation.setSkipped(true);
}
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
} }
private void record(O obs) {
INDArray rawObservation = getInput(obs);
IHistoryProcessor historyProcessor = getHistoryProcessor();
if(historyProcessor != null) {
historyProcessor.record(rawObservation);
}
}
private Map<String, Object> buildChannelsData(final O obs) {
return new HashMap<String, Object>() {{
put("data", obs);
}};
}
@Override @Override
public void close() { public void close() {
wrappedMDP.close(); wrappedMDP.close();

View File

@ -34,6 +34,7 @@ public class AsyncThreadDiscreteTest {
MockPolicy policyMock = new MockPolicy(); MockPolicy policyMock = new MockPolicy();
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
// Act // Act
sut.run(); sut.run();
@ -60,12 +61,6 @@ public class AsyncThreadDiscreteTest {
assertEquals(2, asyncGlobalMock.enqueueCallCount); assertEquals(2, asyncGlobalMock.enqueueCallCount);
// HistoryProcessor // HistoryProcessor
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
for(int i = 0; i < expectedAddValues.length; ++i) {
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
}
double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
for(int i = 0; i < expectedRecordValues.length; ++i) { for(int i = 0; i < expectedRecordValues.length; ++i) {

View File

@ -138,6 +138,7 @@ public class AsyncThreadTest {
asyncGlobal.setMaxLoops(numEpochs); asyncGlobal.setMaxLoops(numEpochs);
listeners.add(listener); listeners.add(listener);
sut.setHistoryProcessor(historyProcessor); sut.setHistoryProcessor(historyProcessor);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
} }
} }
@ -209,7 +210,4 @@ public class AsyncThreadTest {
int nstep; int nstep;
} }
} }
} }

View File

@ -18,7 +18,7 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act // Act
Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition = buildTransition(buildObservation(),
123, 234, new Observation(Nd4j.create(1))); 123, 234, new Observation(Nd4j.create(1)));
sut.store(transition); sut.store(transition);
List<Transition<Integer>> results = sut.getBatch(1); List<Transition<Integer>> results = sut.getBatch(1);
@ -36,11 +36,11 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
// Act // Act
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition1 = buildTransition(buildObservation(),
1, 2, new Observation(Nd4j.create(1))); 1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition2 = buildTransition(buildObservation(),
3, 4, new Observation(Nd4j.create(1))); 3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition3 = buildTransition(buildObservation(),
5, 6, new Observation(Nd4j.create(1))); 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1); sut.store(transition1);
sut.store(transition2); sut.store(transition2);
@ -78,11 +78,11 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act // Act
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition1 = buildTransition(buildObservation(),
1, 2, new Observation(Nd4j.create(1))); 1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition2 = buildTransition(buildObservation(),
3, 4, new Observation(Nd4j.create(1))); 3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition3 = buildTransition(buildObservation(),
5, 6, new Observation(Nd4j.create(1))); 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1); sut.store(transition1);
sut.store(transition2); sut.store(transition2);
@ -100,11 +100,11 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act // Act
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition1 = buildTransition(buildObservation(),
1, 2, new Observation(Nd4j.create(1))); 1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition2 = buildTransition(buildObservation(),
3, 4, new Observation(Nd4j.create(1))); 3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition3 = buildTransition(buildObservation(),
5, 6, new Observation(Nd4j.create(1))); 5, 6, new Observation(Nd4j.create(1)));
sut.store(transition1); sut.store(transition1);
sut.store(transition2); sut.store(transition2);
@ -131,15 +131,15 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act // Act
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition1 = buildTransition(buildObservation(),
1, 2, new Observation(Nd4j.create(1))); 1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition2 = buildTransition(buildObservation(),
3, 4, new Observation(Nd4j.create(1))); 3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition3 = buildTransition(buildObservation(),
5, 6, new Observation(Nd4j.create(1))); 5, 6, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition4 = buildTransition(buildObservation(),
7, 8, new Observation(Nd4j.create(1))); 7, 8, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition5 = buildTransition(buildObservation(),
9, 10, new Observation(Nd4j.create(1))); 9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1); sut.store(transition1);
sut.store(transition2); sut.store(transition2);
@ -168,15 +168,15 @@ public class ExpReplayTest {
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock); ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
// Act // Act
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition1 = buildTransition(buildObservation(),
1, 2, new Observation(Nd4j.create(1))); 1, 2, new Observation(Nd4j.create(1)));
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition2 = buildTransition(buildObservation(),
3, 4, new Observation(Nd4j.create(1))); 3, 4, new Observation(Nd4j.create(1)));
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition3 = buildTransition(buildObservation(),
5, 6, new Observation(Nd4j.create(1))); 5, 6, new Observation(Nd4j.create(1)));
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition4 = buildTransition(buildObservation(),
7, 8, new Observation(Nd4j.create(1))); 7, 8, new Observation(Nd4j.create(1)));
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), Transition<Integer> transition5 = buildTransition(buildObservation(),
9, 10, new Observation(Nd4j.create(1))); 9, 10, new Observation(Nd4j.create(1)));
sut.store(transition1); sut.store(transition1);
sut.store(transition2); sut.store(transition2);
@ -204,4 +204,8 @@ public class ExpReplayTest {
return result; return result;
} }
private Observation buildObservation() {
return new Observation(Nd4j.create(1, 1));
}
} }

View File

@ -193,11 +193,11 @@ public class TransitionTest {
Nd4j.create(obs[1]).reshape(1, 3), Nd4j.create(obs[1]).reshape(1, 3),
Nd4j.create(obs[2]).reshape(1, 3), Nd4j.create(obs[2]).reshape(1, 3),
}; };
return new Observation(history); return new Observation(Nd4j.concat(0, history));
} }
private Observation buildObservation(double[] obs) { private Observation buildObservation(double[] obs) {
return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) }); return new Observation(Nd4j.create(obs).reshape(1, 3));
} }
private Observation buildNextObservation(double[][] obs, double[] nextObs) { private Observation buildNextObservation(double[][] obs, double[] nextObs) {
@ -206,7 +206,7 @@ public class TransitionTest {
Nd4j.create(obs[0]).reshape(1, 3), Nd4j.create(obs[0]).reshape(1, 3),
Nd4j.create(obs[1]).reshape(1, 3), Nd4j.create(obs[1]).reshape(1, 3),
}; };
return new Observation(nextHistory); return new Observation(Nd4j.concat(0, nextHistory));
} }
private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) { private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {

View File

@ -50,6 +50,7 @@ public class QLearningDiscreteTest {
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
sut.setHistoryProcessor(hp); sut.setHistoryProcessor(hp);
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>(); List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
// Act // Act
@ -62,11 +63,7 @@ public class QLearningDiscreteTest {
for(int i = 0; i < expectedRecords.length; ++i) { for(int i = 0; i < expectedRecords.length; ++i) {
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
} }
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
assertEquals(expectedAdds.length, hp.addCalls.size());
for(int i = 0; i < expectedAdds.length; ++i) {
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
}
assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.startMonitorCallCount);
assertEquals(0, hp.stopMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount);

View File

@ -106,7 +106,7 @@ public class DoubleDQNTest {
} }
private Observation buildObservation(double[] data) { private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); return new Observation(Nd4j.create(data).reshape(1, 2));
} }
private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {

View File

@ -105,7 +105,7 @@ public class StandardDQNTest {
} }
private Observation buildObservation(double[] data) { private Observation buildObservation(double[] data) {
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); return new Observation(Nd4j.create(data).reshape(1, 2));
} }
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {

View File

@ -0,0 +1,378 @@
package org.deeplearning4j.rl4j.observation.transform;
import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Test;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.datavec.api.transform.Operation;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.*;
public class TransformProcessTest {
@Test(expected = IllegalArgumentException.class)
public void when_noChannelNameIsSuppliedToBuild_expect_exception() {
// Arrange
TransformProcess.builder().build();
}
@Test(expected = IllegalArgumentException.class)
public void when_callingTransformWithNullArg_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.build("test");
// Act
sut.transform(null, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_callingTransformWithEmptyChannelData_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>();
// Act
sut.transform(channelsData, 0, false);
}
@Test(expected = NullPointerException.class)
public void when_addingNullFilter_expect_nullException() {
// Act
TransformProcess.builder().filter(null);
}
@Test
public void when_fileteredOut_expect_skippedObservationAndFollowingOperationsSkipped() {
// Arrange
IntegerTransformOperationMock transformOperationMock = new IntegerTransformOperationMock();
TransformProcess sut = TransformProcess.builder()
.filter(new FilterOperationMock(true))
.transform("test", transformOperationMock)
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 0, false);
// Assert
assertTrue(result.isSkipped());
assertFalse(transformOperationMock.isCalled);
}
@Test(expected = NullPointerException.class)
public void when_addingTransformOnNullChannel_expect_nullException() {
// Act
TransformProcess.builder().transform(null, new IntegerTransformOperationMock());
}
@Test(expected = NullPointerException.class)
public void when_addingTransformWithNullTransform_expect_nullException() {
// Act
TransformProcess.builder().transform("test", null);
}
@Test
public void when_transformIsCalled_expect_channelDataTransformedInSameOrder() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.filter(new FilterOperationMock(false))
.transform("test", new IntegerTransformOperationMock())
.transform("test", new ToDataSetTransformOperationMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 0, false);
// Assert
assertFalse(result.isSkipped());
assertEquals(-1.0, result.getData().getDouble(0), 0.00001);
}
@Test(expected = NullPointerException.class)
public void when_addingPreProcessOnNullChannel_expect_nullException() {
// Act
TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock());
}
@Test(expected = NullPointerException.class)
public void when_addingPreProcessWithNullTransform_expect_nullException() {
// Act
TransformProcess.builder().transform("test", null);
}
@Test
public void when_preProcessIsCalled_expect_channelDataPreProcessedInSameOrder() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.filter(new FilterOperationMock(false))
.transform("test", new IntegerTransformOperationMock())
.transform("test", new ToDataSetTransformOperationMock())
.preProcess("test", new DataSetPreProcessorMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 0, false);
// Assert
assertFalse(result.isSkipped());
assertEquals(1, result.getData().shape().length);
assertEquals(1, result.getData().shape()[0]);
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
}
@Test(expected = IllegalStateException.class)
public void when_transformingNullData_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.transform("test", new IntegerTransformOperationMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_transformingAndChannelsNotDataSet_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.preProcess("test", new DataSetPreProcessorMock())
.build("test");
// Act
Observation result = sut.transform(null, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_transformingAndChannelsEmptyDataSet_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.preProcess("test", new DataSetPreProcessorMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>();
// Act
Observation result = sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_buildIsCalledWithoutChannelNames_expect_exception() {
// Act
TransformProcess.builder().build();
}
@Test(expected = NullPointerException.class)
public void when_buildIsCalledWithNullChannelName_expect_exception() {
// Act
TransformProcess.builder().build(null);
}
@Test
public void when_resetIsCalled_expect_resettableAreReset() {
// Arrange
ResettableTransformOperationMock resettableOperation = new ResettableTransformOperationMock();
TransformProcess sut = TransformProcess.builder()
.filter(new FilterOperationMock(false))
.transform("test", new IntegerTransformOperationMock())
.transform("test", resettableOperation)
.build("test");
// Act
sut.reset();
// Assert
assertTrue(resettableOperation.isResetCalled);
}
@Test
public void when_buildIsCalledAndAllChannelsAreDataSets_expect_observation() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.transform("test", new ToDataSetTransformOperationMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 123, true);
// Assert
assertFalse(result.isSkipped());
assertEquals(1.0, result.getData().getDouble(0), 0.00001);
}
@Test
public void when_buildIsCalledAndAllChannelsAreINDArrays_expect_observation() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", Nd4j.create(new double[] { 1.0 }));
}};
// Act
Observation result = sut.transform(channelsData, 123, true);
// Assert
assertFalse(result.isSkipped());
assertEquals(1.0, result.getData().getDouble(0), 0.00001);
}
@Test(expected = IllegalStateException.class)
public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
Observation result = sut.transform(channelsData, 123, true);
}
@Test(expected = NullPointerException.class)
public void when_channelDataIsNull_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.transform("test", new IntegerTransformOperationMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", null);
}};
// Act
sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_transformAppliedOnChannelNotInMap_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.transform("test", new IntegerTransformOperationMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("not-test", 1);
}};
// Act
sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_preProcessAppliedOnChannelNotInMap_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.preProcess("test", new DataSetPreProcessorMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("not-test", 1);
}};
// Act
sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_buildContainsChannelNotInMap_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.transform("test", new IntegerTransformOperationMock())
.build("not-test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
sut.transform(channelsData, 0, false);
}
@Test(expected = IllegalArgumentException.class)
public void when_preProcessNotAppliedOnDataSet_expect_exception() {
// Arrange
TransformProcess sut = TransformProcess.builder()
.preProcess("test", new DataSetPreProcessorMock())
.build("test");
Map<String, Object> channelsData = new HashMap<String, Object>() {{
put("test", 1);
}};
// Act
sut.transform(channelsData, 0, false);
}
private static class FilterOperationMock implements FilterOperation {
private final boolean skipped;
public FilterOperationMock(boolean skipped) {
this.skipped = skipped;
}
@Override
public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
return skipped;
}
}
private static class IntegerTransformOperationMock implements Operation<Integer, Integer> {
public boolean isCalled = false;
@Override
public Integer transform(Integer input) {
isCalled = true;
return -input;
}
}
private static class ToDataSetTransformOperationMock implements Operation<Integer, DataSet> {
@Override
public DataSet transform(Integer input) {
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { input }), null);
}
}
private static class ResettableTransformOperationMock implements Operation<Integer, Integer>, ResettableOperation {
private boolean isResetCalled = false;
@Override
public Integer transform(Integer input) {
return input * 10;
}
@Override
public void reset() {
isResetCalled = true;
}
}
private static class DataSetPreProcessorMock implements DataSetPreProcessor {
@Override
public void preProcess(DataSet dataSet) {
dataSet.getFeatures().muli(10.0);
}
}
}

View File

@ -0,0 +1,54 @@
package org.deeplearning4j.rl4j.observation.transform.filter;
import org.deeplearning4j.rl4j.observation.transform.FilterOperation;
import org.junit.Test;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class UniformSkippingFilterTest {
@Test(expected = IllegalArgumentException.class)
public void when_negativeSkipFrame_expect_exception() {
// Act
new UniformSkippingFilter(-1);
}
@Test
public void when_skippingIs4_expect_firstNotSkippedOther3Skipped() {
// Assemble
FilterOperation sut = new UniformSkippingFilter(4);
boolean[] isSkipped = new boolean[8];
// Act
for(int i = 0; i < 8; ++i) {
isSkipped[i] = sut.isSkipped(null, i, false);
}
// Assert
assertFalse(isSkipped[0]);
assertTrue(isSkipped[1]);
assertTrue(isSkipped[2]);
assertTrue(isSkipped[3]);
assertFalse(isSkipped[4]);
assertTrue(isSkipped[5]);
assertTrue(isSkipped[6]);
assertTrue(isSkipped[7]);
}
@Test
public void when_isLastObservation_expect_observationNotSkipped() {
// Assemble
FilterOperation sut = new UniformSkippingFilter(4);
// Act
boolean isSkippedNotLastObservation = sut.isSkipped(null, 1, false);
boolean isSkippedLastObservation = sut.isSkipped(null, 1, true);
// Assert
assertTrue(isSkippedNotLastObservation);
assertFalse(isSkippedLastObservation);
}
}

View File

@ -0,0 +1,31 @@
package org.deeplearning4j.rl4j.observation.transform.operation;
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.*;
public class SimpleNormalizationTransformTest {
@Test(expected = IllegalArgumentException.class)
public void when_maxIsLessThanMin_expect_exception() {
// Arrange
SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0);
}
@Test
public void when_transformIsCalled_expect_inputNormalized() {
// Arrange
SimpleNormalizationTransform sut = new SimpleNormalizationTransform(1.0, 11.0);
INDArray input = Nd4j.create(new double[] { 1.0, 11.0 });
// Act
INDArray result = sut.transform(input);
// Assert
assertEquals(0.0, result.getDouble(0), 0.00001);
assertEquals(1.0, result.getDouble(1), 0.00001);
}
}

View File

@ -23,15 +23,18 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.support.*; import org.deeplearning4j.rl4j.support.*;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -186,8 +189,8 @@ public class PolicyTest {
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
0, 1.0, 0, 0, 0, 0, true); 0, 1.0, 0, 0, 0, 0, true);
MockNeuralNet nnMock = new MockNeuralNet(); MockNeuralNet nnMock = new MockNeuralNet();
MockRefacPolicy sut = new MockRefacPolicy(nnMock);
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
// Act // Act
@ -197,13 +200,6 @@ public class PolicyTest {
assertEquals(1, nnMock.resetCallCount); assertEquals(1, nnMock.resetCallCount);
assertEquals(465.0, totalReward, 0.0001); assertEquals(465.0, totalReward, 0.0001);
// HistoryProcessor
assertEquals(16, hp.addCalls.size());
assertEquals(31, hp.recordCalls.size());
for(int i=0; i <= 30; ++i) {
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
}
// MDP // MDP
assertEquals(1, mdp.resetCount); assertEquals(1, mdp.resetCount);
assertEquals(30, mdp.actions.size()); assertEquals(30, mdp.actions.size());
@ -219,10 +215,15 @@ public class PolicyTest {
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> { public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
private NeuralNet neuralNet; private NeuralNet neuralNet;
private final int[] shape;
private final int skipFrame;
private final int historyLength;
public MockRefacPolicy(NeuralNet neuralNet) { public MockRefacPolicy(NeuralNet neuralNet, int[] shape, int skipFrame, int historyLength) {
this.neuralNet = neuralNet; this.neuralNet = neuralNet;
this.shape = shape;
this.skipFrame = skipFrame;
this.historyLength = historyLength;
} }
@Override @Override
@ -239,5 +240,11 @@ public class PolicyTest {
public Integer nextAction(INDArray input) { public Integer nextAction(INDArray input) {
return (int)input.getDouble(0); return (int)input.getDouble(0);
} }
@Override
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
}
} }
} }

View File

@ -2,6 +2,12 @@ package org.deeplearning4j.rl4j.support;
import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace; import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.api.rng.Random;
@ -77,4 +83,16 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
public MDP newInstance() { public MDP newInstance() {
return null; return null;
} }
public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
return TransformProcess.builder()
.filter(new UniformSkippingFilter(skipFrame))
.transform("data", new EncodableToINDArrayTransform(shape))
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
.transform("data", HistoryMergeTransform.builder()
.elementStore(new CircularFifoStore(historyLength))
.build())
.build("data");
}
} }