diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java index af1fe18ea..c8a0c38d9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,57 +17,43 @@ package org.deeplearning4j.rl4j.observation; import lombok.Getter; -import lombok.Setter; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.factory.Nd4j; /** - * Presently only a dummy container. Will contain observation channels when done. + * Represent an observation from the environment + * + * @author Alexandre Boulanger */ public class Observation { - // TODO: Presently only a dummy container. Will contain observation channels when done. - private final DataSet data; + /** + * A singleton representing a skipped observation + */ + public static Observation SkippedObservation = new Observation(null); - @Getter @Setter - private boolean skipped; + /** + * @return A INDArray containing the data of the observation + */ + @Getter + private final INDArray data; - public Observation(INDArray[] data) { - this(data, false); + public boolean isSkipped() { + return data == null; } - public Observation(INDArray[] data, boolean shouldReshape) { - INDArray features = Nd4j.concat(0, data); - if(shouldReshape) { - features = reshape(features); - } - this.data = new org.nd4j.linalg.dataset.DataSet(features, null); - } - - // FIXME: Remove -- only used in unit tests public Observation(INDArray data) { - this.data = new org.nd4j.linalg.dataset.DataSet(data, null); - } - - private INDArray reshape(INDArray source) { - long[] shape = source.shape(); - long[] nshape = new long[shape.length + 1]; - nshape[0] = 1; - System.arraycopy(shape, 0, nshape, 1, shape.length); - - return source.reshape(nshape); - } - - private Observation(DataSet data) { this.data = data; } + /** + * Creates a duplicate instance of the current observation + * @return + */ public Observation dup() { - return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null)); - } + if(data == null) { + return SkippedObservation; + } - public INDArray getData() { - return data.getFeatures(); + return new Observation(data.dup()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java new file mode 100644 index 000000000..53bc01421 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java @@ -0,0 +1,231 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +import org.apache.commons.lang3.NotImplementedException; +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.shade.guava.collect.Maps; +import org.datavec.api.transform.Operation; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +/** + * A TransformProcess will build an {@link Observation Observation} from the raw data coming from the environment. + * Three types of steps are available: + * 1) A {@link FilterOperation FilterOperation}: Used to determine if an observation should be skipped. + * 2) An {@link Operation Operation}: Applies a transform and/or conversion to an observation channel. + * 3) A {@link DataSetPreProcessor DataSetPreProcessor}: Applies a DataSetPreProcessor to the observation channel. The data in the channel must be a DataSet. + * + * Instances of the three types above can be called in any order. The only requirement is that when build() is called, + * all channels must be instances of INDArrays or DataSets + * + * NOTE: Presently, only single-channels observations are supported. + * + * @author Alexandre Boulanger + */ +public class TransformProcess { + + private final List> operations; + private final String[] channelNames; + private final HashSet operationsChannelNames; + + private TransformProcess(Builder builder, String... channelNames) { + operations = builder.operations; + this.channelNames = channelNames; + this.operationsChannelNames = builder.requiredChannelNames; + } + + /** + * This method will call reset() of all steps implementing {@link ResettableOperation ResettableOperation} in the transform process. + */ + public void reset() { + for(Map.Entry entry : operations) { + if(entry.getValue() instanceof ResettableOperation) { + ((ResettableOperation) entry.getValue()).reset(); + } + } + } + + /** + * Transforms the channel data into an Observation or a skipped observation depending on the specific steps in the transform process. + * + * @param channelsData A Map that maps the channel name to its data. + * @param currentObservationStep The observation's step number within the episode. + * @param isFinalObservation True if the observation is the last of the episode. + * @return An observation (may be a skipped observation) + */ + public Observation transform(Map channelsData, int currentObservationStep, boolean isFinalObservation) { + // null or empty channelData + Preconditions.checkArgument(channelsData != null && channelsData.size() != 0, "Error: channelsData not supplied."); + + // Check that all channels have data + for(Map.Entry channel : channelsData.entrySet()) { + Preconditions.checkNotNull(channel.getValue(), "Error: data of channel '%s' is null", channel.getKey()); + } + + // Check that all required channels are present + for(String channelName : operationsChannelNames) { + Preconditions.checkArgument(channelsData.containsKey(channelName), "The channelsData map does not contain the channel '%s'", channelName); + } + + for(Map.Entry entry : operations) { + + // Filter + if(entry.getValue() instanceof FilterOperation) { + FilterOperation filterOperation = (FilterOperation)entry.getValue(); + if(filterOperation.isSkipped(channelsData, currentObservationStep, isFinalObservation)) { + return Observation.SkippedObservation; + } + } + + // Transform + // null results are considered skipped observations + else if(entry.getValue() instanceof Operation) { + Operation transformOperation = (Operation)entry.getValue(); + Object transformed = transformOperation.transform(channelsData.get(entry.getKey())); + if(transformed == null) { + return Observation.SkippedObservation; + } + channelsData.replace(entry.getKey(), transformed); + } + + // PreProcess + else if(entry.getValue() instanceof DataSetPreProcessor) { + Object channelData = channelsData.get(entry.getKey()); + DataSetPreProcessor dataSetPreProcessor = (DataSetPreProcessor)entry.getValue(); + if(!(channelData instanceof DataSet)) { + throw new IllegalArgumentException("The channel data must be a DataSet to call preProcess"); + } + dataSetPreProcessor.preProcess((DataSet)channelData); + } + + else { + throw new IllegalArgumentException(String.format("Unknown operation: '%s'", entry.getValue().getClass().getName())); + } + } + + // Check that all channels used to build the observation are instances of + // INDArray or DataSet + // TODO: Add support for an interface with a toINDArray() method + for(String channelName : channelNames) { + Object channelData = channelsData.get(channelName); + + INDArray finalChannelData; + if(channelData instanceof DataSet) { + finalChannelData = ((DataSet)channelData).getFeatures(); + } + else if(channelData instanceof INDArray) { + finalChannelData = (INDArray) channelData; + } + else { + throw new IllegalStateException("All channels used to build the observation must be instances of DataSet or INDArray"); + } + + // The dimension 0 of all INDArrays must be 1 (batch count) + channelsData.replace(channelName, INDArrayHelper.forceCorrectShape(finalChannelData)); + } + + // TODO: Add support to multi-channel observations + INDArray data = ((INDArray) channelsData.get(channelNames[0])); + return new Observation(data); + } + + /** + * @return An instance of a builder + */ + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final List> operations = new ArrayList>(); + private final HashSet requiredChannelNames = new HashSet(); + + /** + * Add a filter to the transform process steps. Used to skip observations on certain conditions. + * See {@link FilterOperation FilterOperation} + * @param filterOperation An instance + */ + public Builder filter(FilterOperation filterOperation) { + Preconditions.checkNotNull(filterOperation, "The filterOperation must not be null"); + + operations.add((Map.Entry)Maps.immutableEntry(null, filterOperation)); + return this; + } + + /** + * Add a transform to the steps. The transform can change the data and / or change the type of the data + * (e.g. Byte[] to a ImageWritable) + * + * @param targetChannel The name of the channel to which the transform is applied. + * @param transformOperation An instance of {@link Operation Operation} + */ + public Builder transform(String targetChannel, Operation transformOperation) { + Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null"); + Preconditions.checkNotNull(transformOperation, "The transformOperation must not be null"); + + requiredChannelNames.add(targetChannel); + operations.add((Map.Entry)Maps.immutableEntry(targetChannel, transformOperation)); + return this; + } + + /** + * Add a DataSetPreProcessor to the steps. The channel must be a DataSet instance at this step. + * @param targetChannel The name of the channel to which the pre processor is applied. + * @param dataSetPreProcessor + */ + public Builder preProcess(String targetChannel, DataSetPreProcessor dataSetPreProcessor) { + Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null"); + Preconditions.checkNotNull(dataSetPreProcessor, "The dataSetPreProcessor must not be null"); + + requiredChannelNames.add(targetChannel); + operations.add((Map.Entry)Maps.immutableEntry(targetChannel, dataSetPreProcessor)); + return this; + } + + /** + * Builds the TransformProcess. + * @param channelNames A subset of channel names to be used to build the observation + * @return An instance of TransformProcess + */ + public TransformProcess build(String... channelNames) { + if(channelNames.length == 0) { + throw new IllegalArgumentException("At least one channel must be supplied."); + } + + for(String channelName : channelNames) { + Preconditions.checkNotNull(channelName, "Error: got a null channel name"); + requiredChannelNames.add(channelName); + } + + // TODO: Remove when multi-channel observation is supported + if(channelNames.length != 1) { + throw new NotImplementedException("Multi-channel observations is not presently supported."); + } + + return new TransformProcess(this, channelNames); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java similarity index 88% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java index b1c32abed..3eaeec4dc 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java @@ -25,14 +25,14 @@ import org.nd4j.linalg.factory.Nd4j; import static org.bytedeco.opencv.global.opencv_core.CV_32FC; -public class EncodableToImageWriteableTransform implements Operation { +public class EncodableToImageWritableTransform implements Operation { private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); private final int height; private final int width; private final int colorChannels; - public EncodableToImageWriteableTransform(int height, int width, int colorChannels) { + public EncodableToImageWritableTransform(int height, int width, int colorChannels) { this.height = height; this.width = width; this.colorChannels = colorChannels; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java similarity index 51% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java index d20f9f9f8..3a48c128a 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java @@ -1,3 +1,18 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.deeplearning4j.rl4j.observation.transform.legacy; import org.datavec.api.transform.Operation; @@ -10,13 +25,13 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; -public class ImageWriteableToINDArrayTransform implements Operation { +public class ImageWritableToINDArrayTransform implements Operation { private final int height; private final int width; private final NativeImageLoader loader; - public ImageWriteableToINDArrayTransform(int height, int width) { + public ImageWritableToINDArrayTransform(int height, int width) { this.height = height; this.width = width; this.loader = new NativeImageLoader(height, width); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java new file mode 100644 index 000000000..592a1f86d --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.datavec.api.transform.Operation; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class SimpleNormalizationTransform implements Operation { + + private final double offset; + private final double divisor; + + public SimpleNormalizationTransform(double min, double max) { + Preconditions.checkArgument(min < max, "Min must be smaller than max."); + + this.offset = min; + this.divisor = (max - min); + } + + @Override + public INDArray transform(INDArray input) { + if(offset != 0.0) { + input.subi(offset); + } + + input.divi(divisor); + + return input; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index 26546a923..b0f46ef57 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -3,37 +3,95 @@ package org.deeplearning4j.rl4j.util; import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; +import org.datavec.image.transform.ColorConversionTransform; +import org.datavec.image.transform.CropImageTransform; +import org.datavec.image.transform.MultiImageTransform; +import org.datavec.image.transform.ResizeImageTransform; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; +import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import java.util.HashMap; +import java.util.Map; + +import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY; + public class LegacyMDPWrapper> implements MDP { @Getter private final MDP wrappedMDP; @Getter private final WrapperObservationSpace observationSpace; + private final int[] shape; - @Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC) + @Setter + private TransformProcess transformProcess; + + @Getter(AccessLevel.PRIVATE) private IHistoryProcessor historyProcessor; private final EpochStepCounter epochStepCounter; private int skipFrame = 1; - private int requiredFrame = 0; public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { this.wrappedMDP = wrappedMDP; - this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); + this.shape = wrappedMDP.getObservationSpace().getShape(); + this.observationSpace = new WrapperObservationSpace(shape); this.historyProcessor = historyProcessor; this.epochStepCounter = epochStepCounter; + + setHistoryProcessor(historyProcessor); + } + + public void setHistoryProcessor(IHistoryProcessor historyProcessor) { + this.historyProcessor = historyProcessor; + createTransformProcess(); + } + + private void createTransformProcess() { + IHistoryProcessor historyProcessor = getHistoryProcessor(); + + if(historyProcessor != null && shape.length == 3) { + int skipFrame = historyProcessor.getConf().getSkipFrame(); + + int finalHeight = historyProcessor.getConf().getCroppingHeight(); + int finalWidth = historyProcessor.getConf().getCroppingWidth(); + + transformProcess = TransformProcess.builder() + .filter(new UniformSkippingFilter(skipFrame)) + .transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2])) + .transform("data", new MultiImageTransform( + new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()), + new ColorConversionTransform(COLOR_BGR2GRAY), + new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth) + )) + .transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth)) + .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) + .transform("data", HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .build()) + .build("data"); + } + else { + transformProcess = TransformProcess.builder() + .transform("data", new EncodableToINDArrayTransform(shape)) + .build("data"); + } } @Override @@ -43,25 +101,17 @@ public class LegacyMDPWrapper> implements MDP channelsData = buildChannelsData(rawResetResponse); + return transformProcess.transform(channelsData, 0, false); } @Override @@ -71,32 +121,32 @@ public class LegacyMDPWrapper> implements MDP rawStepReply = wrappedMDP.step(a); INDArray rawObservation = getInput(rawStepReply.getObservation()); - int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; - if(historyProcessor != null) { historyProcessor.record(rawObservation); - - if (stepOfObservation % skipFrame == 0) { - historyProcessor.add(rawObservation); - } } - Observation observation; - if(historyProcessor != null && stepOfObservation >= requiredFrame) { - observation = new Observation(historyProcessor.getHistory(), true); - observation.getData().muli(1.0 / historyProcessor.getScale()); - } - else { - observation = new Observation(new INDArray[] { rawObservation }, false); - } - - if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) { - observation.setSkipped(true); - } + int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; + Map channelsData = buildChannelsData(rawStepReply.getObservation()); + Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } + private void record(O obs) { + INDArray rawObservation = getInput(obs); + + IHistoryProcessor historyProcessor = getHistoryProcessor(); + if(historyProcessor != null) { + historyProcessor.record(rawObservation); + } + } + + private Map buildChannelsData(final O obs) { + return new HashMap() {{ + put("data", obs); + }}; + } + @Override public void close() { wrappedMDP.close(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 8a0090b62..bc396502f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -34,6 +34,7 @@ public class AsyncThreadDiscreteTest { MockPolicy policyMock = new MockPolicy(); MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); // Act sut.run(); @@ -60,12 +61,6 @@ public class AsyncThreadDiscreteTest { assertEquals(2, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; - assertEquals(expectedAddValues.length, hpMock.addCalls.size()); - for(int i = 0; i < expectedAddValues.length; ++i) { - assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); - } - double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); for(int i = 0; i < expectedRecordValues.length; ++i) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index b01105294..3dea25936 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -138,6 +138,7 @@ public class AsyncThreadTest { asyncGlobal.setMaxLoops(numEpochs); listeners.add(listener); sut.setHistoryProcessor(historyProcessor); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); } } @@ -209,7 +210,4 @@ public class AsyncThreadTest { int nstep; } } - - - } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java index 73b27776a..08c8ba24e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -18,7 +18,7 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition = buildTransition(buildObservation(), 123, 234, new Observation(Nd4j.create(1))); sut.store(transition); List> results = sut.getBatch(1); @@ -36,11 +36,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -78,11 +78,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -100,11 +100,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -131,15 +131,15 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); - Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition4 = buildTransition(buildObservation(), 7, 8, new Observation(Nd4j.create(1))); - Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition5 = buildTransition(buildObservation(), 9, 10, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -168,15 +168,15 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); - Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition4 = buildTransition(buildObservation(), 7, 8, new Observation(Nd4j.create(1))); - Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition5 = buildTransition(buildObservation(), 9, 10, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -204,4 +204,8 @@ public class ExpReplayTest { return result; } + + private Observation buildObservation() { + return new Observation(Nd4j.create(1, 1)); + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java index 944e41a31..374b6a140 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java @@ -193,11 +193,11 @@ public class TransitionTest { Nd4j.create(obs[1]).reshape(1, 3), Nd4j.create(obs[2]).reshape(1, 3), }; - return new Observation(history); + return new Observation(Nd4j.concat(0, history)); } private Observation buildObservation(double[] obs) { - return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) }); + return new Observation(Nd4j.create(obs).reshape(1, 3)); } private Observation buildNextObservation(double[][] obs, double[] nextObs) { @@ -206,7 +206,7 @@ public class TransitionTest { Nd4j.create(obs[0]).reshape(1, 3), Nd4j.create(obs[1]).reshape(1, 3), }; - return new Observation(nextHistory); + return new Observation(Nd4j.concat(0, nextHistory)); } private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index ee8b365f0..58aaab297 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -50,6 +50,7 @@ public class QLearningDiscreteTest { IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); List> results = new ArrayList<>(); // Act @@ -62,11 +63,7 @@ public class QLearningDiscreteTest { for(int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 }; - assertEquals(expectedAdds.length, hp.addCalls.size()); - for(int i = 0; i < expectedAdds.length; ++i) { - assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001); - } + assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 760666f33..798bddf0d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -106,7 +106,7 @@ public class DoubleDQNTest { } private Observation buildObservation(double[] data) { - return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + return new Observation(Nd4j.create(data).reshape(1, 2)); } private Transition builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index c540646a8..3e3701669 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -105,7 +105,7 @@ public class StandardDQNTest { } private Observation buildObservation(double[] data) { - return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + return new Observation(Nd4j.create(data).reshape(1, 2)); } private Transition buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java new file mode 100644 index 000000000..fe79bdfc7 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java @@ -0,0 +1,378 @@ +package org.deeplearning4j.rl4j.observation.transform; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Test; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; +import org.datavec.api.transform.Operation; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.*; + +public class TransformProcessTest { + @Test(expected = IllegalArgumentException.class) + public void when_noChannelNameIsSuppliedToBuild_expect_exception() { + // Arrange + TransformProcess.builder().build(); + } + + @Test(expected = IllegalArgumentException.class) + public void when_callingTransformWithNullArg_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + + // Act + sut.transform(null, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_callingTransformWithEmptyChannelData_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap(); + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = NullPointerException.class) + public void when_addingNullFilter_expect_nullException() { + // Act + TransformProcess.builder().filter(null); + } + + @Test + public void when_fileteredOut_expect_skippedObservationAndFollowingOperationsSkipped() { + // Arrange + IntegerTransformOperationMock transformOperationMock = new IntegerTransformOperationMock(); + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(true)) + .transform("test", transformOperationMock) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertTrue(result.isSkipped()); + assertFalse(transformOperationMock.isCalled); + } + + @Test(expected = NullPointerException.class) + public void when_addingTransformOnNullChannel_expect_nullException() { + // Act + TransformProcess.builder().transform(null, new IntegerTransformOperationMock()); + } + + @Test(expected = NullPointerException.class) + public void when_addingTransformWithNullTransform_expect_nullException() { + // Act + TransformProcess.builder().transform("test", null); + } + + @Test + public void when_transformIsCalled_expect_channelDataTransformedInSameOrder() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", new ToDataSetTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertFalse(result.isSkipped()); + assertEquals(-1.0, result.getData().getDouble(0), 0.00001); + } + + @Test(expected = NullPointerException.class) + public void when_addingPreProcessOnNullChannel_expect_nullException() { + // Act + TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock()); + } + + @Test(expected = NullPointerException.class) + public void when_addingPreProcessWithNullTransform_expect_nullException() { + // Act + TransformProcess.builder().transform("test", null); + } + + @Test + public void when_preProcessIsCalled_expect_channelDataPreProcessedInSameOrder() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", new ToDataSetTransformOperationMock()) + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertFalse(result.isSkipped()); + assertEquals(1, result.getData().shape().length); + assertEquals(1, result.getData().shape()[0]); + assertEquals(-10.0, result.getData().getDouble(0), 0.00001); + } + + @Test(expected = IllegalStateException.class) + public void when_transformingNullData_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_transformingAndChannelsNotDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + + // Act + Observation result = sut.transform(null, 0, false); + } + + + @Test(expected = IllegalArgumentException.class) + public void when_transformingAndChannelsEmptyDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap(); + + // Act + Observation result = sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_buildIsCalledWithoutChannelNames_expect_exception() { + // Act + TransformProcess.builder().build(); + } + + @Test(expected = NullPointerException.class) + public void when_buildIsCalledWithNullChannelName_expect_exception() { + // Act + TransformProcess.builder().build(null); + } + + @Test + public void when_resetIsCalled_expect_resettableAreReset() { + // Arrange + ResettableTransformOperationMock resettableOperation = new ResettableTransformOperationMock(); + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", resettableOperation) + .build("test"); + + // Act + sut.reset(); + + // Assert + assertTrue(resettableOperation.isResetCalled); + } + + @Test + public void when_buildIsCalledAndAllChannelsAreDataSets_expect_observation() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new ToDataSetTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + + // Assert + assertFalse(result.isSkipped()); + + assertEquals(1.0, result.getData().getDouble(0), 0.00001); + } + + @Test + public void when_buildIsCalledAndAllChannelsAreINDArrays_expect_observation() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap() {{ + put("test", Nd4j.create(new double[] { 1.0 })); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + + // Assert + assertFalse(result.isSkipped()); + + assertEquals(1.0, result.getData().getDouble(0), 0.00001); + } + + @Test(expected = IllegalStateException.class) + public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + } + + @Test(expected = NullPointerException.class) + public void when_channelDataIsNull_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", null); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_transformAppliedOnChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_preProcessAppliedOnChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_buildContainsChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("not-test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_preProcessNotAppliedOnDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + private static class FilterOperationMock implements FilterOperation { + + private final boolean skipped; + + public FilterOperationMock(boolean skipped) { + this.skipped = skipped; + } + + @Override + public boolean isSkipped(Map channelsData, int currentObservationStep, boolean isFinalObservation) { + return skipped; + } + } + + private static class IntegerTransformOperationMock implements Operation { + + public boolean isCalled = false; + + @Override + public Integer transform(Integer input) { + isCalled = true; + return -input; + } + } + + private static class ToDataSetTransformOperationMock implements Operation { + + @Override + public DataSet transform(Integer input) { + return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { input }), null); + } + } + + private static class ResettableTransformOperationMock implements Operation, ResettableOperation { + + private boolean isResetCalled = false; + + @Override + public Integer transform(Integer input) { + return input * 10; + } + + @Override + public void reset() { + isResetCalled = true; + } + } + + private static class DataSetPreProcessorMock implements DataSetPreProcessor { + + @Override + public void preProcess(DataSet dataSet) { + dataSet.getFeatures().muli(10.0); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java new file mode 100644 index 000000000..3aa5a17cf --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java @@ -0,0 +1,54 @@ +package org.deeplearning4j.rl4j.observation.transform.filter; + +import org.deeplearning4j.rl4j.observation.transform.FilterOperation; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class UniformSkippingFilterTest { + + @Test(expected = IllegalArgumentException.class) + public void when_negativeSkipFrame_expect_exception() { + // Act + new UniformSkippingFilter(-1); + } + + @Test + public void when_skippingIs4_expect_firstNotSkippedOther3Skipped() { + // Assemble + FilterOperation sut = new UniformSkippingFilter(4); + boolean[] isSkipped = new boolean[8]; + + // Act + for(int i = 0; i < 8; ++i) { + isSkipped[i] = sut.isSkipped(null, i, false); + } + + // Assert + assertFalse(isSkipped[0]); + assertTrue(isSkipped[1]); + assertTrue(isSkipped[2]); + assertTrue(isSkipped[3]); + + assertFalse(isSkipped[4]); + assertTrue(isSkipped[5]); + assertTrue(isSkipped[6]); + assertTrue(isSkipped[7]); + } + + @Test + public void when_isLastObservation_expect_observationNotSkipped() { + // Assemble + FilterOperation sut = new UniformSkippingFilter(4); + + // Act + boolean isSkippedNotLastObservation = sut.isSkipped(null, 1, false); + boolean isSkippedLastObservation = sut.isSkipped(null, 1, true); + + // Assert + assertTrue(isSkippedNotLastObservation); + assertFalse(isSkippedLastObservation); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java new file mode 100644 index 000000000..b330a4fb0 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java @@ -0,0 +1,31 @@ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class SimpleNormalizationTransformTest { + @Test(expected = IllegalArgumentException.class) + public void when_maxIsLessThanMin_expect_exception() { + // Arrange + SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0); + } + + @Test + public void when_transformIsCalled_expect_inputNormalized() { + // Arrange + SimpleNormalizationTransform sut = new SimpleNormalizationTransform(1.0, 11.0); + INDArray input = Nd4j.create(new double[] { 1.0, 11.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertEquals(0.0, result.getDouble(0), 0.00001); + assertEquals(1.0, result.getDouble(1), 0.00001); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 2262f1789..0707e16ab 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -23,15 +23,18 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -186,8 +189,8 @@ public class PolicyTest { QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, 0, 1.0, 0, 0, 0, 0, true); MockNeuralNet nnMock = new MockNeuralNet(); - MockRefacPolicy sut = new MockRefacPolicy(nnMock); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); // Act @@ -197,13 +200,6 @@ public class PolicyTest { assertEquals(1, nnMock.resetCallCount); assertEquals(465.0, totalReward, 0.0001); - // HistoryProcessor - assertEquals(16, hp.addCalls.size()); - assertEquals(31, hp.recordCalls.size()); - for(int i=0; i <= 30; ++i) { - assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001); - } - // MDP assertEquals(1, mdp.resetCount); assertEquals(30, mdp.actions.size()); @@ -219,10 +215,15 @@ public class PolicyTest { public static class MockRefacPolicy extends Policy { private NeuralNet neuralNet; + private final int[] shape; + private final int skipFrame; + private final int historyLength; - public MockRefacPolicy(NeuralNet neuralNet) { - + public MockRefacPolicy(NeuralNet neuralNet, int[] shape, int skipFrame, int historyLength) { this.neuralNet = neuralNet; + this.shape = shape; + this.skipFrame = skipFrame; + this.historyLength = historyLength; } @Override @@ -239,5 +240,11 @@ public class PolicyTest { public Integer nextAction(INDArray input) { return (int)input.getDouble(0); } + + @Override + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); + return super.refacInitMdp(mdpWrapper, hp, epochStepCounter); + } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java index c0ac23a2d..bbed87624 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java @@ -2,6 +2,12 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.rng.Random; @@ -77,4 +83,16 @@ public class MockMDP implements MDP { public MDP newInstance() { return null; } + + public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) { + return TransformProcess.builder() + .filter(new UniformSkippingFilter(skipFrame)) + .transform("data", new EncodableToINDArrayTransform(shape)) + .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) + .transform("data", HistoryMergeTransform.builder() + .elementStore(new CircularFifoStore(historyLength)) + .build()) + .build("data"); + } + }