From ee6aae268f642daf00a85b1a276dad63a2e6e7d7 Mon Sep 17 00:00:00 2001 From: Alexandre Boulanger <44292157+aboulang2002@users.noreply.github.com> Date: Fri, 19 Jul 2019 20:28:20 -0400 Subject: [PATCH] RL4J refac: Added some observation transform classes (#7958) * Added observation classes and tests Signed-off-by: unknown * Now uses DataSetPreProcessors Signed-off-by: Alexandre Boulanger * CompositeDataSetPreProcessor can now stop processing on empty dataset; Some DataSetPreProcessors moving from RL4J to ND4J Signed-off-by: Alexandre Boulanger * Did requested minor changes Signed-off-by: Alexandre Boulanger Signed-off-by: Alexandre Boulanger --- .../CompositeDataSetPreProcessor.java | 19 +- .../CropAndResizeDataSetPreProcessor.java | 105 +++++++++++ .../PermuteDataSetPreProcessor.java | 87 ++++++++++ .../RGBtoGrayscaleDataSetPreProcessor.java | 70 ++++++++ .../CompositeDataSetPreProcessorTest.java | 102 +++++++++++ .../CropAndResizeDataSetPreProcessorTest.java | 131 ++++++++++++++ .../PermuteDataSetPreProcessorTest.java | 124 +++++++++++++ ...RGBtoGrayscaleDataSetPreProcessorTest.java | 123 +++++++++++++ .../PoolingDataSetPreProcessor.java | 130 ++++++++++++++ .../ResettableDataSetPreProcessor.java | 28 +++ .../SkippingDataSetPreProcessor.java | 62 +++++++ .../ChannelStackPoolContentAssembler.java | 53 ++++++ .../pooling/CircularFifoObservationPool.java | 95 ++++++++++ .../preprocessor/pooling/ObservationPool.java | 32 ++++ .../pooling/PoolContentAssembler.java | 30 ++++ .../PoolingDataSetPreProcessorTest.java | 164 ++++++++++++++++++ .../SkippingDataSetPreProcessorTest.java | 70 ++++++++ .../ChannelStackPoolContentAssemblerTest.java | 41 +++++ .../CircularFifoObservationPoolTest.java | 100 +++++++++++ 19 files changed, 1565 insertions(+), 1 deletion(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessor.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessor.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java index cf868257c..14a8bf489 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.dataset.api.preprocessor; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -29,19 +30,35 @@ import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; */ public class CompositeDataSetPreProcessor implements DataSetPreProcessor { + private final boolean stopOnEmptyDataSet; private DataSetPreProcessor[] preProcessors; /** * @param preProcessors Preprocessors to apply. They will be applied in this order */ - public CompositeDataSetPreProcessor(DataSetPreProcessor... preProcessors){ + public CompositeDataSetPreProcessor(DataSetPreProcessor... preProcessors) { + this(false, preProcessors); + } + + public CompositeDataSetPreProcessor(boolean stopOnEmptyDataSet, DataSetPreProcessor... preProcessors){ + this.stopOnEmptyDataSet = stopOnEmptyDataSet; this.preProcessors = preProcessors; } @Override public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(stopOnEmptyDataSet && dataSet.isEmpty()) { + return; + } + for(DataSetPreProcessor p : preProcessors){ p.preProcess(dataSet); + + if(stopOnEmptyDataSet && dataSet.isEmpty()) { + return; + } } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessor.java new file mode 100644 index 000000000..c515b1c5a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessor.java @@ -0,0 +1,105 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * The CropAndResizeDataSetPreProcessor will crop and resize the processed dataset. + * NOTE: The data format must be NHWC + * + * @author Alexandre Boulanger + */ +public class CropAndResizeDataSetPreProcessor implements DataSetPreProcessor { + + public enum ResizeMethod { + Bilinear, + NearestNeighbor + } + + private final long[] resizedShape; + private final INDArray indices; + private final INDArray resize; + private final INDArray boxes; + private final int method; + + /** + * + * @param originalHeight Height of the input datasets + * @param originalWidth Width of the input datasets + * @param cropYStart y coord of the starting point on the input datasets + * @param cropXStart x coord of the starting point on the input datasets + * @param resizedHeight Height of the output dataset + * @param resizedWidth Width of the output dataset + * @param numChannels + * @param resizeMethod + */ + public CropAndResizeDataSetPreProcessor(int originalHeight, int originalWidth, int cropYStart, int cropXStart, int resizedHeight, int resizedWidth, int numChannels, ResizeMethod resizeMethod) { + Preconditions.checkArgument(originalHeight > 0, "originalHeight must be greater than 0, got %s", originalHeight); + Preconditions.checkArgument(originalWidth > 0, "originalWidth must be greater than 0, got %s", originalWidth); + Preconditions.checkArgument(cropYStart >= 0, "cropYStart must be greater or equal to 0, got %s", cropYStart); + Preconditions.checkArgument(cropXStart >= 0, "cropXStart must be greater or equal to 0, got %s", cropXStart); + Preconditions.checkArgument(resizedHeight > 0, "resizedHeight must be greater than 0, got %s", resizedHeight); + Preconditions.checkArgument(resizedWidth > 0, "resizedWidth must be greater than 0, got %s", resizedWidth); + Preconditions.checkArgument(numChannels > 0, "numChannels must be greater than 0, got %s", numChannels); + + resizedShape = new long[] { 1, resizedHeight, resizedWidth, numChannels }; + + boxes = Nd4j.create(new float[] { + (float)cropYStart / (float)originalHeight, + (float)cropXStart / (float)originalWidth, + (float)(cropYStart + resizedHeight) / (float)originalHeight, + (float)(cropXStart + resizedWidth) / (float)originalWidth + }, new long[] { 1, 4 }, DataType.FLOAT); + indices = Nd4j.create(new int[] { 0 }, new long[] { 1, 1 }, DataType.INT); + + resize = Nd4j.create(new int[] { resizedHeight, resizedWidth }, new long[] { 1, 2 }, DataType.INT); + method = resizeMethod == ResizeMethod.Bilinear ? 0 : 1; + } + + /** + * NOTE: The data format must be NHWC + */ + @Override + public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(dataSet.isEmpty()) { + return; + } + + INDArray input = dataSet.getFeatures(); + INDArray output = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, input.dataType()), false); + + CustomOp op = DynamicCustomOp.builder("crop_and_resize") + .addInputs(input, boxes, indices, resize) + .addIntegerArguments(method) + .addOutputs(output) + .build(); + Nd4j.getExecutioner().exec(op); + + dataSet.setFeatures(output); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessor.java new file mode 100644 index 000000000..f2aded02b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessor.java @@ -0,0 +1,87 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; + +/** + * The PermuteDataSetPreProcessor will rearrange the dimensions. + * There are two pre-defined permutation types: + * - from NCHW to NHWC + * - from NHWC to NCHW + * + * Or, pass the new order to the ctor. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last. + * + * @author Alexandre Boulanger + */ +public class PermuteDataSetPreProcessor implements DataSetPreProcessor { + + private final PermutationTypes permutationType; + private final int[] rearrange; + + public enum PermutationTypes { NCHWtoNHWC, NHWCtoNCHW, Custom } + + public PermuteDataSetPreProcessor(PermutationTypes permutationType) { + Preconditions.checkArgument(permutationType != PermutationTypes.Custom, "Use the ctor PermuteDataSetPreProcessor(int... rearrange) for custom permutations."); + + this.permutationType = permutationType; + rearrange = null; + } + + /** + * @param rearrange The new order. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last. + */ + public PermuteDataSetPreProcessor(int... rearrange) { + + this.permutationType = PermutationTypes.Custom; + this.rearrange = rearrange; + } + + @Override + public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(dataSet.isEmpty()) { + return; + } + + INDArray input = dataSet.getFeatures(); + INDArray output; + switch (permutationType) { + case NCHWtoNHWC: + output = input.permute(0, 2, 3, 1); + break; + + case NHWCtoNCHW: + output = input.permute(0, 3, 1, 2); + break; + + case Custom: + output = input.permute(rearrange); + break; + + default: + output = input; + break; + } + + dataSet.setFeatures(output); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java new file mode 100644 index 000000000..5042510ce --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessor.java @@ -0,0 +1,70 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; + +/** + * The RGBtoGrayscaleDataSetPreProcessor will turn a DataSet of a RGB image into a grayscale one. + * NOTE: Expects data format to be NCHW. After processing, the channel dimension is eliminated. (NCHW -> NHW) + * + * @author Alexandre Boulanger + */ +public class RGBtoGrayscaleDataSetPreProcessor implements DataSetPreProcessor { + + private static final float RED_RATIO = 0.3f; + private static final float GREEN_RATIO = 0.59f; + private static final float BLUE_RATIO = 0.11f; + + @Override + public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(dataSet.isEmpty()) { + return; + } + + INDArray originalFeatures = dataSet.getFeatures(); + long[] originalShape = originalFeatures.shape(); + + // result shape is NHW + INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]); + + for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) { + // Extract channels + INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW + INDArray R = itemFeatures.slice(0, 0); // shape is HW + INDArray G = itemFeatures.slice(1, 0); + INDArray B = itemFeatures.slice(2, 0); + + // Convert + R.muli(RED_RATIO); + G.muli(GREEN_RATIO); + B.muli(BLUE_RATIO); + R.addi(G).addi(B); + + // FIXME: int cast + result.putSlice((int)n, R); + } + + dataSet.setFeatures(result); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java new file mode 100644 index 000000000..a2af67dc9 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessorTest.java @@ -0,0 +1,102 @@ +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.junit.Test; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class CompositeDataSetPreProcessorTest { + @Test(expected = NullPointerException.class) + public void when_preConditionsIsNull_expect_NullPointerException() { + // Assemble + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); + + // Act + sut.preProcess(null); + + } + + @Test + public void when_dataSetIsEmpty_expect_emptyDataSet() { + // Assemble + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(); + DataSet ds = new DataSet(null, null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(ds.isEmpty()); + } + + @Test + public void when_notStoppingOnEmptyDataSet_expect_allPreProcessorsCalled() { + // Assemble + TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); + TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(preProcessor1, preProcessor2); + DataSet ds = new DataSet(Nd4j.rand(2, 2), null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(preProcessor1.hasBeenCalled); + assertTrue(preProcessor2.hasBeenCalled); + } + + @Test + public void when_stoppingOnEmptyDataSetAndFirstPreProcessorClearDS_expect_firstPreProcessorsCalled() { + // Assemble + TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true); + TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true); + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2); + DataSet ds = new DataSet(Nd4j.rand(2, 2), null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(preProcessor1.hasBeenCalled); + assertFalse(preProcessor2.hasBeenCalled); + } + + @Test + public void when_stoppingOnEmptyDataSet_expect_firstPreProcessorsCalled() { + // Assemble + TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(false); + TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(false); + CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2); + DataSet ds = new DataSet(Nd4j.rand(2, 2), null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(preProcessor1.hasBeenCalled); + assertTrue(preProcessor2.hasBeenCalled); + } + + public static class TestDataSetPreProcessor implements DataSetPreProcessor { + + private final boolean clearDataSet; + + public boolean hasBeenCalled; + + public TestDataSetPreProcessor(boolean clearDataSet) { + this.clearDataSet = clearDataSet; + } + + @Override + public void preProcess(org.nd4j.linalg.dataset.api.DataSet dataSet) { + hasBeenCalled = true; + if(clearDataSet) { + dataSet.setFeatures(null); + } + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java new file mode 100644 index 000000000..63abfffcd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/CropAndResizeDataSetPreProcessorTest.java @@ -0,0 +1,131 @@ +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class CropAndResizeDataSetPreProcessorTest { + + @Test(expected = IllegalArgumentException.class) + public void when_originalHeightIsZero_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(0, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_originalWidthIsZero_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 0, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_yStartIsNegative_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, -1, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_xStartIsNegative_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, -1, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 0, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 0, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = IllegalArgumentException.class) + public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException() { + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 0, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + } + + @Test(expected = NullPointerException.class) + public void when_dataSetIsNull_expect_NullPointerException() { + // Assemble + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + + // Act + sut.preProcess(null); + } + + @Test + public void when_dataSetIsEmpty_expect_emptyDataSet() { + // Assemble + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + DataSet ds = new DataSet(null, null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(ds.isEmpty()); + } + + @Test + public void when_dataSetIs15wx10h_expect_3wx4hDataSet() { + // Assemble + int numChannels = 3; + int height = 10; + int width = 15; + CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(height, width, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor); + INDArray input = Nd4j.create(LongShapeDescriptor.fromShape(new int[] { 1, height, width, numChannels }, DataType.FLOAT), true); + for(int c = 0; c < numChannels; ++c) { + for(int h = 0; h < height; ++h) { + for(int w = 0; w < width; ++w) { + input.putScalar(0, h, w, c, c*100 + h*10 + w); + } + } + } + + DataSet ds = new DataSet(input, null); + + // Act + sut.preProcess(ds); + + // Assert + INDArray results = ds.getFeatures(); + long[] shape = results.shape(); + assertEquals(1, shape[0]); + assertEquals(4, shape[1]); + assertEquals(3, shape[2]); + assertEquals(3, shape[3]); + + // Test a few values + assertEquals(55.0, results.getDouble(0, 0, 0, 0), 0.0); + assertEquals(155.0, results.getDouble(0, 0, 0, 1), 0.0); + assertEquals(255.0, results.getDouble(0, 0, 0, 2), 0.0); + + assertEquals(56.0, results.getDouble(0, 0, 1, 0), 0.0); + assertEquals(156.0, results.getDouble(0, 0, 1, 1), 0.0); + assertEquals(256.0, results.getDouble(0, 0, 1, 2), 0.0); + + assertEquals(57.0, results.getDouble(0, 0, 2, 0), 0.0); + assertEquals(157.0, results.getDouble(0, 0, 2, 1), 0.0); + assertEquals(257.0, results.getDouble(0, 0, 2, 2), 0.0); + + assertEquals(65.0, results.getDouble(0, 1, 0, 0), 0.0); + assertEquals(165.0, results.getDouble(0, 1, 0, 1), 0.0); + assertEquals(265.0, results.getDouble(0, 1, 0, 2), 0.0); + + assertEquals(66.0, results.getDouble(0, 1, 1, 0), 0.0); + assertEquals(166.0, results.getDouble(0, 1, 1, 1), 0.0); + assertEquals(266.0, results.getDouble(0, 1, 1, 2), 0.0); + + assertEquals(75.0, results.getDouble(0, 2, 0, 0), 0.0); + assertEquals(175.0, results.getDouble(0, 2, 0, 1), 0.0); + assertEquals(275.0, results.getDouble(0, 2, 0, 2), 0.0); + + assertEquals(76.0, results.getDouble(0, 2, 1, 0), 0.0); + assertEquals(176.0, results.getDouble(0, 2, 1, 1), 0.0); + assertEquals(276.0, results.getDouble(0, 2, 1, 2), 0.0); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java new file mode 100644 index 000000000..acbac85df --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/PermuteDataSetPreProcessorTest.java @@ -0,0 +1,124 @@ +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class PermuteDataSetPreProcessorTest { + + @Test(expected = NullPointerException.class) + public void when_dataSetIsNull_expect_NullPointerException() { + // Assemble + PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); + + // Act + sut.preProcess(null); + } + + @Test + public void when_emptyDatasetInInputdataSetIsNCHW_expect_emptyDataSet() { + // Assemble + PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); + DataSet ds = new DataSet(null, null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(ds.isEmpty()); + } + + @Test + public void when_dataSetIsNCHW_expect_dataSetTransformedToNHWC() { + // Assemble + int numChannels = 3; + int height = 5; + int width = 4; + PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC); + INDArray input = Nd4j.create(1, numChannels, height, width); + for(int c = 0; c < numChannels; ++c) { + for(int h = 0; h < height; ++h) { + for(int w = 0; w < width; ++w) { + input.putScalar(0, c, h, w, c*100.0 + h*10.0 + w); + } + } + } + DataSet ds = new DataSet(input, null); + + // Act + sut.preProcess(ds); + + // Assert + INDArray result = ds.getFeatures(); + long[] shape = result.shape(); + assertEquals(1, shape[0]); + assertEquals(height, shape[1]); + assertEquals(width, shape[2]); + assertEquals(numChannels, shape[3]); + + assertEquals(0.0, result.getDouble(0, 0, 0, 0), 0.0); + assertEquals(1.0, result.getDouble(0, 0, 1, 0), 0.0); + assertEquals(2.0, result.getDouble(0, 0, 2, 0), 0.0); + assertEquals(3.0, result.getDouble(0, 0, 3, 0), 0.0); + + assertEquals(110.0, result.getDouble(0, 1, 0, 1), 0.0); + assertEquals(111.0, result.getDouble(0, 1, 1, 1), 0.0); + assertEquals(112.0, result.getDouble(0, 1, 2, 1), 0.0); + assertEquals(113.0, result.getDouble(0, 1, 3, 1), 0.0); + + assertEquals(210.0, result.getDouble(0, 1, 0, 2), 0.0); + assertEquals(211.0, result.getDouble(0, 1, 1, 2), 0.0); + assertEquals(212.0, result.getDouble(0, 1, 2, 2), 0.0); + assertEquals(213.0, result.getDouble(0, 1, 3, 2), 0.0); + + } + + @Test + public void when_dataSetIsNHWC_expect_dataSetTransformedToNCHW() { + // Assemble + int numChannels = 3; + int height = 5; + int width = 4; + PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NHWCtoNCHW); + INDArray input = Nd4j.create(1, height, width, numChannels); + for(int c = 0; c < numChannels; ++c) { + for(int h = 0; h < height; ++h) { + for(int w = 0; w < width; ++w) { + input.putScalar(new int[] { 0, h, w, c }, c*100.0 + h*10.0 + w); + } + } + } + DataSet ds = new DataSet(input, null); + + // Act + sut.preProcess(ds); + + // Assert + INDArray result = ds.getFeatures(); + long[] shape = result.shape(); + assertEquals(1, shape[0]); + assertEquals(numChannels, shape[1]); + assertEquals(height, shape[2]); + assertEquals(width, shape[3]); + + assertEquals(0.0, result.getDouble(0, 0, 0, 0), 0.0); + assertEquals(1.0, result.getDouble(0, 0, 0, 1), 0.0); + assertEquals(2.0, result.getDouble(0, 0, 0, 2), 0.0); + assertEquals(3.0, result.getDouble(0, 0, 0, 3), 0.0); + + assertEquals(110.0, result.getDouble(0, 1, 1, 0), 0.0); + assertEquals(111.0, result.getDouble(0, 1, 1, 1), 0.0); + assertEquals(112.0, result.getDouble(0, 1, 1, 2), 0.0); + assertEquals(113.0, result.getDouble(0, 1, 1, 3), 0.0); + + assertEquals(210.0, result.getDouble(0, 2, 1, 0), 0.0); + assertEquals(211.0, result.getDouble(0, 2, 1, 1), 0.0); + assertEquals(212.0, result.getDouble(0, 2, 1, 2), 0.0); + assertEquals(213.0, result.getDouble(0, 2, 1, 3), 0.0); + + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java new file mode 100644 index 000000000..b0408d8b7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/api/preprocessor/RGBtoGrayscaleDataSetPreProcessorTest.java @@ -0,0 +1,123 @@ +package org.nd4j.linalg.dataset.api.preprocessor; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class RGBtoGrayscaleDataSetPreProcessorTest { + + @Test(expected = NullPointerException.class) + public void when_dataSetIsNull_expect_NullPointerException() { + // Assemble + RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); + + // Act + sut.preProcess(null); + } + + @Test + public void when_dataSetIsEmpty_expect_EmptyDataSet() { + // Assemble + RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); + DataSet ds = new DataSet(null, null); + + // Act + sut.preProcess(ds); + + // Assert + assertTrue(ds.isEmpty()); + } + + @Test + public void when_colorsAreConverted_expect_grayScaleResult() { + // Assign + int numChannels = 3; + int height = 1; + int width = 5; + + RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor(); + INDArray input = Nd4j.create(2, numChannels, height, width); + + // Black, Example 1 + input.putScalar(0, 0, 0, 0, 0.0 ); + input.putScalar(0, 1, 0, 0, 0.0 ); + input.putScalar(0, 2, 0, 0, 0.0 ); + + // White, Example 1 + input.putScalar(0, 0, 0, 1, 255.0 ); + input.putScalar(0, 1, 0, 1, 255.0 ); + input.putScalar(0, 2, 0, 1, 255.0 ); + + // Red, Example 1 + input.putScalar(0, 0, 0, 2, 255.0 ); + input.putScalar(0, 1, 0, 2, 0.0 ); + input.putScalar(0, 2, 0, 2, 0.0 ); + + // Green, Example 1 + input.putScalar(0, 0, 0, 3, 0.0 ); + input.putScalar(0, 1, 0, 3, 255.0 ); + input.putScalar(0, 2, 0, 3, 0.0 ); + + // Blue, Example 1 + input.putScalar(0, 0, 0, 4, 0.0 ); + input.putScalar(0, 1, 0, 4, 0.0 ); + input.putScalar(0, 2, 0, 4, 255.0 ); + + + // Black, Example 2 + input.putScalar(1, 0, 0, 4, 0.0 ); + input.putScalar(1, 1, 0, 4, 0.0 ); + input.putScalar(1, 2, 0, 4, 0.0 ); + + // White, Example 2 + input.putScalar(1, 0, 0, 3, 255.0 ); + input.putScalar(1, 1, 0, 3, 255.0 ); + input.putScalar(1, 2, 0, 3, 255.0 ); + + // Red, Example 2 + input.putScalar(1, 0, 0, 2, 255.0 ); + input.putScalar(1, 1, 0, 2, 0.0 ); + input.putScalar(1, 2, 0, 2, 0.0 ); + + // Green, Example 2 + input.putScalar(1, 0, 0, 1, 0.0 ); + input.putScalar(1, 1, 0, 1, 255.0 ); + input.putScalar(1, 2, 0, 1, 0.0 ); + + // Blue, Example 2 + input.putScalar(1, 0, 0, 0, 0.0 ); + input.putScalar(1, 1, 0, 0, 0.0 ); + input.putScalar(1, 2, 0, 0, 255.0 ); + + DataSet ds = new DataSet(input, null); + + // Act + sut.preProcess(ds); + + // Assert + INDArray result = ds.getFeatures(); + long[] shape = result.shape(); + + assertEquals(3, shape.length); + assertEquals(2, shape[0]); + assertEquals(1, shape[1]); + assertEquals(5, shape[2]); + + assertEquals(0.0, result.getDouble(0, 0, 0), 0.05); + assertEquals(255.0, result.getDouble(0, 0, 1), 0.05); + assertEquals(255.0 * 0.3, result.getDouble(0, 0, 2), 0.05); + assertEquals(255.0 * 0.59, result.getDouble(0, 0, 3), 0.05); + assertEquals(255.0 * 0.11, result.getDouble(0, 0, 4), 0.05); + + assertEquals(0.0, result.getDouble(1, 0, 4), 0.05); + assertEquals(255.0, result.getDouble(1, 0, 3), 0.05); + assertEquals(255.0 * 0.3, result.getDouble(1, 0, 2), 0.05); + assertEquals(255.0 * 0.59, result.getDouble(1, 0, 1), 0.05); + assertEquals(255.0 * 0.11, result.getDouble(1, 0, 0), 0.05); + + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.java new file mode 100644 index 000000000..e3fad8b3a --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.java @@ -0,0 +1,130 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor; + +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ChannelStackPoolContentAssembler; +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler; +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.CircularFifoObservationPool; +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; + +/** + * The PoolingDataSetPreProcessor will accumulate features from incoming DataSets and will assemble its content + * into a DataSet containing a single example. + * + * There are two special cases: + * 1) preProcess will return without doing anything if the input DataSet is empty + * 2) When the pool has not yet filled, the data from the incoming DataSet is stored in the pool but the DataSet is emptied + * on exit. + *
+ * The PoolingDataSetPreProcessor requires two sub components:
+ * 1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...) + * The default is a Circular FIFO. + * 2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...) + * The default is stacking along the dimension 0. + * + * @author Alexandre Boulanger + */ +public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor { + private final ObservationPool observationPool; + private final PoolContentAssembler poolContentAssembler; + + protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder) + { + observationPool = builder.observationPool; + poolContentAssembler = builder.poolContentAssembler; + } + + /** + * Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the + * Policy/Learning class and other DataSetPreProcessors + * + * @param dataSet + */ + @Override + public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(dataSet.isEmpty()) { + return; + } + + Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported"); + + // store a duplicate in the pool + observationPool.add(dataSet.getFeatures().slice(0, 0).dup()); + if(!observationPool.isAtFullCapacity()) { + dataSet.setFeatures(null); + return; + } + + INDArray result = poolContentAssembler.assemble(observationPool.get()); + + // return a DataSet containing only 1 example (the result) + long[] resultShape = result.shape(); + long[] newShape = new long[resultShape.length + 1]; + newShape[0] = 1; + System.arraycopy(resultShape, 0, newShape, 1, resultShape.length); + + dataSet.setFeatures(result.reshape(newShape)); + } + + public static PoolingDataSetPreProcessor.Builder builder() { + return new PoolingDataSetPreProcessor.Builder(); + } + + @Override + public void reset() { + observationPool.reset(); + } + + public static class Builder { + private ObservationPool observationPool; + private PoolContentAssembler poolContentAssembler; + + /** + * Default is CircularFifoObservationPool + */ + public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) { + this.observationPool = observationPool; + return this; + } + + /** + * Default is ChannelStackPoolContentAssembler + */ + public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) { + this.poolContentAssembler = poolContentAssembler; + return this; + } + + public PoolingDataSetPreProcessor build() { + if(observationPool == null) { + observationPool = new CircularFifoObservationPool(); + } + + if(poolContentAssembler == null) { + poolContentAssembler = new ChannelStackPoolContentAssembler(); + } + + return new PoolingDataSetPreProcessor(this); + } + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java new file mode 100644 index 000000000..46ff4e39c --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor; + +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; + +/** + * A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games). + * + * @author Alexandre Boulanger + */ +public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor { + public abstract void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java new file mode 100644 index 000000000..940966823 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor; + +import lombok.Builder; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.dataset.api.DataSet; + +/** + * The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty + * the input DataSet when skipping. + * + * @author Alexandre Boulanger + */ +public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor { + + private final int skipFrame; + + private int currentIdx = 0; + + /** + * @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations. + */ + @Builder + public SkippingDataSetPreProcessor(int skipFrame) { + Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame); + this.skipFrame = skipFrame; + } + + @Override + public void preProcess(DataSet dataSet) { + Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); + + if(dataSet.isEmpty()) { + return; + } + + if(currentIdx++ % skipFrame != 0) { + dataSet.setFeatures(null); + dataSet.setLabels(null); + } + } + + @Override + public void reset() { + currentIdx = 0; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java new file mode 100644 index 000000000..d53f5c8c8 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java @@ -0,0 +1,53 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will + * stack along the dimension 0. For example if the pool elements are of shape [ Height, Width ] + * the output will be of shape [ Stacked, Height, Width ] + * + * @author Alexandre Boulanger + */ +public class ChannelStackPoolContentAssembler implements PoolContentAssembler { + + /** + * Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0. + * + * @param poolContent Array of INDArray + * @return A new INDArray with 1 more dimension than the input elements + */ + @Override + public INDArray assemble(INDArray[] poolContent) + { + // build the new shape + long[] elementShape = poolContent[0].shape(); + long[] newShape = new long[elementShape.length + 1]; + newShape[0] = poolContent.length; + System.arraycopy(elementShape, 0, newShape, 1, elementShape.length); + + // put pool elements in result + INDArray result = Nd4j.create(newShape); + for(int i = 0; i < poolContent.length; ++i) { + result.putRow(i, poolContent[i]); + } + return result; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java new file mode 100644 index 000000000..6eb950e48 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.apache.commons.collections4.queue.CircularFifoQueue; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue + * with a fixed size that replaces its oldest element if full. + * + * @author Alexandre Boulanger + */ +public class CircularFifoObservationPool implements ObservationPool { + private static final int DEFAULT_POOL_SIZE = 4; + + private final CircularFifoQueue queue; + + private CircularFifoObservationPool(Builder builder) { + queue = new CircularFifoQueue<>(builder.poolSize); + } + + public CircularFifoObservationPool() + { + this(DEFAULT_POOL_SIZE); + } + + public CircularFifoObservationPool(int poolSize) + { + Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize); + queue = new CircularFifoQueue<>(poolSize); + } + + /** + * Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one. + * @param elem + */ + public void add(INDArray elem) { + queue.add(elem); + } + + /** + * @return The content of the pool, returned in order from oldest to newest. + */ + public INDArray[] get() { + int size = queue.size(); + INDArray[] array = new INDArray[size]; + for (int i = 0; i < size; ++i) { + array[i] = queue.get(i).castTo(Nd4j.dataType()); + } + return array; + } + + public boolean isAtFullCapacity() { + return queue.isAtFullCapacity(); + } + + @Override + public void reset() { + queue.clear(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private int poolSize = DEFAULT_POOL_SIZE; + + public Builder poolSize(int poolSize) { + this.poolSize = poolSize; + return this; + } + + public CircularFifoObservationPool build() { + return new CircularFifoObservationPool(this); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java new file mode 100644 index 000000000..1d8363ad8 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * ObservationPool is used with the PoolingDataSetPreProcessor. Used to supervise how data from the + * PoolingDataSetPreProcessor is stored. + * + * @author Alexandre Boulanger + */ +public interface ObservationPool { + void add(INDArray observation); + INDArray[] get(); + boolean isAtFullCapacity(); + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java new file mode 100644 index 000000000..63b382a09 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray + * returned by the ObservationPool is packaged into the single INDArray that will be set + * in the DataSet of PoolingDataSetPreProcessor.preProcess + * + * @author Alexandre Boulanger + */ +public interface PoolContentAssembler { + INDArray assemble(INDArray[] poolContent); +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java new file mode 100644 index 000000000..db239b0f3 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java @@ -0,0 +1,164 @@ +package org.deeplearning4j.rl4j.observation.preprocessor; + +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool; +import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +public class PoolingDataSetPreProcessorTest { + + @Test(expected = NullPointerException.class) + public void when_dataSetIsNull_expect_NullPointerException() { + // Assemble + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); + + // Act + sut.preProcess(null); + } + + @Test(expected = IllegalArgumentException.class) + public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() { + // Assemble + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); + + // Act + sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null)); + } + + @Test + public void when_dataSetIsEmpty_expect_EmptyDataSet() { + // Assemble + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); + DataSet ds = new DataSet(null, null); + + // Act + sut.preProcess(ds); + + // Assert + Assert.assertTrue(ds.isEmpty()); + } + + @Test + public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() { + // Arrange + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); + DataSet[] observations = new DataSet[5]; + INDArray[] inputs = new INDArray[5]; + + + // Act + for(int i = 0; i < 5; ++i) { + inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 }); + DataSet input = new DataSet(inputs[i], null); + sut.preProcess(input); + observations[i] = input; + } + + // Assert + assertTrue(observations[0].isEmpty()); + assertTrue(observations[1].isEmpty()); + assertTrue(observations[2].isEmpty()); + + for(int i = 0; i < 4; ++i) { + assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); + assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); + assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); + assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); + } + + for(int i = 0; i < 4; ++i) { + assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); + assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); + assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); + assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); + } + + } + + @Test + public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() { + // Arrange + INDArray input = Nd4j.rand(1, 1); + TestObservationPool pool = new TestObservationPool(); + TestPoolContentAssembler assembler = new TestPoolContentAssembler(); + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder() + .observablePool(pool) + .poolContentAssembler(assembler) + .build(); + + // Act + sut.preProcess(new DataSet(input, null)); + + // Assert + assertTrue(pool.isAtFullCapacityCalled); + assertTrue(pool.isGetCalled); + assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0); + assertTrue(assembler.assembleIsCalled); + } + + @Test + public void when_pastInputChanges_expect_outputNotChanged() { + // Arrange + INDArray input = Nd4j.zeros(1, 1); + TestObservationPool pool = new TestObservationPool(); + TestPoolContentAssembler assembler = new TestPoolContentAssembler(); + PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder() + .observablePool(pool) + .poolContentAssembler(assembler) + .build(); + + // Act + sut.preProcess(new DataSet(input, null)); + input.putScalar(0, 0, 1.0); + + // Assert + assertEquals(0.0, pool.observation.getDouble(0), 0.0); + } + + private static class TestObservationPool implements ObservationPool { + + public INDArray observation; + public boolean isGetCalled; + public boolean isAtFullCapacityCalled; + private boolean isResetCalled; + + @Override + public void add(INDArray observation) { + this.observation = observation; + } + + @Override + public INDArray[] get() { + isGetCalled = true; + return new INDArray[0]; + } + + @Override + public boolean isAtFullCapacity() { + isAtFullCapacityCalled = true; + return true; + } + + @Override + public void reset() { + isResetCalled = true; + } + } + + private static class TestPoolContentAssembler implements PoolContentAssembler { + + public boolean assembleIsCalled; + + @Override + public INDArray assemble(INDArray[] poolContent) { + assembleIsCalled = true; + return Nd4j.create(1, 1); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java new file mode 100644 index 000000000..3f1de3426 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java @@ -0,0 +1,70 @@ +package org.deeplearning4j.rl4j.observation.preprocessor; + +import org.junit.Test; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class SkippingDataSetPreProcessorTest { + @Test(expected = IllegalArgumentException.class) + public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() { + SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0); + } + + @Test(expected = IllegalArgumentException.class) + public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() { + SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() + .skipFrame(0) + .build(); + } + + @Test + public void when_skipFrameIs3_expect_Skip2OutOf3() { + // Arrange + SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() + .skipFrame(3) + .build(); + DataSet[] results = new DataSet[4]; + + // Act + for(int i = 0; i < 4; ++i) { + results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); + sut.preProcess(results[i]); + } + + // Assert + assertFalse(results[0].isEmpty()); + assertTrue(results[1].isEmpty()); + assertTrue(results[2].isEmpty()); + assertFalse(results[3].isEmpty()); + } + + @Test + public void when_resetIsCalled_expect_skippingIsReset() { + // Arrange + SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() + .skipFrame(3) + .build(); + DataSet[] results = new DataSet[4]; + + // Act + results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); + results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); + results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); + results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); + + sut.preProcess(results[0]); + sut.preProcess(results[1]); + sut.reset(); + sut.preProcess(results[2]); + sut.preProcess(results[3]); + + // Assert + assertFalse(results[0].isEmpty()); + assertTrue(results[1].isEmpty()); + assertFalse(results[2].isEmpty()); + assertTrue(results[3].isEmpty()); + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java new file mode 100644 index 000000000..de0db015c --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java @@ -0,0 +1,41 @@ +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; + +public class ChannelStackPoolContentAssemblerTest { + + @Test + public void when_assemble_expect_poolContentStackedOnChannel() { + // Assemble + ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler(); + INDArray[] poolContent = new INDArray[] { + Nd4j.rand(2, 2), + Nd4j.rand(2, 2), + }; + + // Act + INDArray result = sut.assemble(poolContent); + + // Assert + assertEquals(3, result.shape().length); + assertEquals(2, result.shape()[0]); + assertEquals(2, result.shape()[1]); + assertEquals(2, result.shape()[2]); + + assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001); + assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001); + assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001); + assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001); + + assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001); + assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001); + assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001); + assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001); + + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java new file mode 100644 index 000000000..88e7b33dd --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java @@ -0,0 +1,100 @@ +package org.deeplearning4j.rl4j.observation.preprocessor.pooling; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class CircularFifoObservationPoolTest { + + @Test(expected = IllegalArgumentException.class) + public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() { + CircularFifoObservationPool sut = new CircularFifoObservationPool(0); + } + + @Test + public void when_poolIsEmpty_expect_NotReady() { + // Assemble + CircularFifoObservationPool sut = new CircularFifoObservationPool(); + + // Act + boolean isReady = sut.isAtFullCapacity(); + + // Assert + assertFalse(isReady); + } + + @Test + public void when_notEnoughElementsInPool_expect_notReady() { + // Assemble + CircularFifoObservationPool sut = new CircularFifoObservationPool(); + sut.add(Nd4j.create(new double[] { 123.0 })); + + // Act + boolean isReady = sut.isAtFullCapacity(); + + // Assert + assertFalse(isReady); + } + + @Test + public void when_enoughElementsInPool_expect_ready() { + // Assemble + CircularFifoObservationPool sut = CircularFifoObservationPool.builder() + .poolSize(2) + .build(); + sut.add(Nd4j.createFromArray(123.0)); + sut.add(Nd4j.createFromArray(123.0)); + + // Act + boolean isReady = sut.isAtFullCapacity(); + + // Assert + assertTrue(isReady); + } + + @Test + public void when_addMoreThanSize_expect_getReturnOnlyLastElements() { + // Assemble + CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); + sut.add(Nd4j.createFromArray(0.0)); + sut.add(Nd4j.createFromArray(1.0)); + sut.add(Nd4j.createFromArray(2.0)); + sut.add(Nd4j.createFromArray(3.0)); + sut.add(Nd4j.createFromArray(4.0)); + sut.add(Nd4j.createFromArray(5.0)); + sut.add(Nd4j.createFromArray(6.0)); + + // Act + INDArray[] result = sut.get(); + + // Assert + assertEquals(3.0, result[0].getDouble(0), 0.0); + assertEquals(4.0, result[1].getDouble(0), 0.0); + assertEquals(5.0, result[2].getDouble(0), 0.0); + assertEquals(6.0, result[3].getDouble(0), 0.0); + } + + @Test + public void when_resetIsCalled_expect_poolContentFlushed() { + // Assemble + CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); + sut.add(Nd4j.createFromArray(0.0)); + sut.add(Nd4j.createFromArray(1.0)); + sut.add(Nd4j.createFromArray(2.0)); + sut.add(Nd4j.createFromArray(3.0)); + sut.add(Nd4j.createFromArray(4.0)); + sut.add(Nd4j.createFromArray(5.0)); + sut.add(Nd4j.createFromArray(6.0)); + sut.reset(); + + // Act + INDArray[] result = sut.get(); + + // Assert + assertEquals(0, result.length); + } +}