From 58aa5a3a9bbd2c1cbe041c81a57182e4c5d62606 Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Thu, 5 Mar 2020 14:43:13 +0900 Subject: [PATCH] RL4J: Add TransformProcess, part 1 (#8711) * Added TransformProcess, part 1 Signed-off-by: unknown * Renamed TemporalMergeTransform to HistoryMergeTransform Signed-off-by: unknown * changed INDArrayHelper to use Nd4j.expandDims Signed-off-by: Alexandre Boulanger * Adjusted copyrights Signed-off-by: unknown --- .../org/datavec/api/transform/ColumnOp.java | 5 +- .../org/datavec/api/transform/Operation.java | 48 ++--- .../image/transform/ImageTransform.java | 11 +- rl4j/rl4j-core/pom.xml | 7 + .../INDArrayHelper.java} | 69 ++++--- .../learning/async/AsyncThreadDiscrete.java | 3 +- .../qlearning/discrete/QLearningDiscrete.java | 2 +- .../PoolingDataSetPreProcessor.java | 130 ------------- .../SkippingDataSetPreProcessor.java | 62 ------ .../transform/FilterOperation.java | 35 ++++ .../ResettableOperation.java} | 58 +++--- .../filter/UniformSkippingFilter.java | 45 +++++ .../legacy/EncodableToINDArrayTransform.java | 41 ++++ .../EncodableToImageWriteableTransform.java | 48 +++++ .../ImageWriteableToINDArrayTransform.java | 37 ++++ .../operation/HistoryMergeTransform.java | 147 +++++++++++++++ .../historymerge/CircularFifoStore.java} | 177 ++++++++---------- .../historymerge/HistoryMergeAssembler.java | 35 ++++ .../HistoryMergeElementStore.java | 51 +++++ .../historymerge/HistoryStackAssembler.java} | 105 +++++------ .../deeplearning4j/rl4j/policy/Policy.java | 2 +- .../rl4j/helper/INDArrayHelperTest.java | 38 ++++ .../PoolingDataSetPreProcessorTest.java | 164 ---------------- .../SkippingDataSetPreProcessorTest.java | 70 ------- .../ChannelStackPoolContentAssemblerTest.java | 41 ---- .../CircularFifoObservationPoolTest.java | 100 ---------- .../operation/HistoryMergeTransformTest.java | 166 ++++++++++++++++ .../historymerge/CircularFifoStoreTest.java | 77 ++++++++ .../HistoryStackAssemblerTest.java | 37 ++++ 29 files changed, 989 insertions(+), 822 deletions(-) rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java => datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java (63%) rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/{observation/preprocessor/pooling/PoolContentAssembler.java => helper/INDArrayHelper.java} (52%) delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessor.java delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/{preprocessor/pooling/ObservationPool.java => transform/ResettableOperation.java} (62%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/{preprocessor/pooling/CircularFifoObservationPool.java => transform/operation/historymerge/CircularFifoStore.java} (52%) create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java create mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/{preprocessor/pooling/ChannelStackPoolContentAssembler.java => transform/operation/historymerge/HistoryStackAssembler.java} (62%) create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java create mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java index d8caf1962..ae2d17788 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java @@ -26,10 +26,7 @@ import org.datavec.api.transform.schema.Schema; * * @author Adam Gibson */ -public interface ColumnOp { - /** Get the output schema for this transformation, given an input schema */ - Schema transform(Schema inputSchema); - +public interface ColumnOp extends Operation { /** Set the input schema. */ diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java similarity index 63% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java rename to datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java index 46ff4e39c..ecb58543b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java @@ -1,28 +1,20 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.nd4j.linalg.dataset.api.DataSetPreProcessor; - -/** - * A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games). - * - * @author Alexandre Boulanger - */ -public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor { - public abstract void reset(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.transform; + +public interface Operation { + TOut transform(TIn input); +} diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java index afcdf894f..bdd74df13 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java @@ -16,6 +16,7 @@ package org.datavec.image.transform; +import org.datavec.api.transform.Operation; import org.datavec.image.data.ImageWritable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -29,15 +30,7 @@ import java.util.Random; */ @JsonInclude(JsonInclude.Include.NON_NULL) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ImageTransform { - - /** - * Takes an image and returns a transformed image. - * - * @param image to transform, null == end of stream - * @return transformed image - */ - ImageWritable transform(ImageWritable image); +public interface ImageTransform extends Operation { /** * Takes an image and returns a transformed image. diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index a78157603..c08615250 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -102,6 +102,13 @@ gson ${gson.version} + + + org.datavec + datavec-api + ${datavec.version} + + diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java similarity index 52% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 63b382a09..7d93b1175 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -1,30 +1,39 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray - * returned by the ObservationPool is packaged into the single INDArray that will be set - * in the DataSet of PoolingDataSetPreProcessor.preProcess - * - * @author Alexandre Boulanger - */ -public interface PoolContentAssembler { - INDArray assemble(INDArray[] poolContent); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.helper; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * INDArray helper methods used by RL4J + * + * @author Alexandre Boulanger + */ +public class INDArrayHelper { + /** + * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray. + * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape. + * + * @param source A INDArray + * @return The source INDArray with the correct shape + */ + public static INDArray forceCorrectShape(INDArray source) { + return source.shape()[0] == 1 + ? source + : Nd4j.expandDims(source, 0); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 0b3eb1c72..27d49c366 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -90,7 +90,7 @@ public abstract class AsyncThreadDiscrete accuReward += stepReply.getReward() * getConf().getRewardFactor(); //if it's not a skipped frame, you can do a step of training - if (!obs.isSkipped() || stepReply.isDone()) { + if (!obs.isSkipped()) { INDArray[] output = current.outputAll(obs.getData()); rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); @@ -99,7 +99,6 @@ public abstract class AsyncThreadDiscrete } obs = stepReply.getObservation(); - reward += stepReply.getReward(); incrementStep(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 0fa0f33e6..614ddb793 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -158,7 +158,7 @@ public abstract class QLearningDiscrete extends QLearning - * The PoolingDataSetPreProcessor requires two sub components:
- * 1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...) - * The default is a Circular FIFO. - * 2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...) - * The default is stacking along the dimension 0. - * - * @author Alexandre Boulanger - */ -public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor { - private final ObservationPool observationPool; - private final PoolContentAssembler poolContentAssembler; - - protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder) - { - observationPool = builder.observationPool; - poolContentAssembler = builder.poolContentAssembler; - } - - /** - * Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the - * Policy/Learning class and other DataSetPreProcessors - * - * @param dataSet - */ - @Override - public void preProcess(DataSet dataSet) { - Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); - - if(dataSet.isEmpty()) { - return; - } - - Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported"); - - // store a duplicate in the pool - observationPool.add(dataSet.getFeatures().slice(0, 0).dup()); - if(!observationPool.isAtFullCapacity()) { - dataSet.setFeatures(null); - return; - } - - INDArray result = poolContentAssembler.assemble(observationPool.get()); - - // return a DataSet containing only 1 example (the result) - long[] resultShape = result.shape(); - long[] newShape = new long[resultShape.length + 1]; - newShape[0] = 1; - System.arraycopy(resultShape, 0, newShape, 1, resultShape.length); - - dataSet.setFeatures(result.reshape(newShape)); - } - - public static PoolingDataSetPreProcessor.Builder builder() { - return new PoolingDataSetPreProcessor.Builder(); - } - - @Override - public void reset() { - observationPool.reset(); - } - - public static class Builder { - private ObservationPool observationPool; - private PoolContentAssembler poolContentAssembler; - - /** - * Default is CircularFifoObservationPool - */ - public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) { - this.observationPool = observationPool; - return this; - } - - /** - * Default is ChannelStackPoolContentAssembler - */ - public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) { - this.poolContentAssembler = poolContentAssembler; - return this; - } - - public PoolingDataSetPreProcessor build() { - if(observationPool == null) { - observationPool = new CircularFifoObservationPool(); - } - - if(poolContentAssembler == null) { - poolContentAssembler = new ChannelStackPoolContentAssembler(); - } - - return new PoolingDataSetPreProcessor(this); - } - } - -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java deleted file mode 100644 index 940966823..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor; - -import lombok.Builder; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.dataset.api.DataSet; - -/** - * The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty - * the input DataSet when skipping. - * - * @author Alexandre Boulanger - */ -public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor { - - private final int skipFrame; - - private int currentIdx = 0; - - /** - * @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations. - */ - @Builder - public SkippingDataSetPreProcessor(int skipFrame) { - Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame); - this.skipFrame = skipFrame; - } - - @Override - public void preProcess(DataSet dataSet) { - Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); - - if(dataSet.isEmpty()) { - return; - } - - if(currentIdx++ % skipFrame != 0) { - dataSet.setFeatures(null); - dataSet.setLabels(null); - } - } - - @Override - public void reset() { - currentIdx = 0; - } -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java new file mode 100644 index 000000000..74b67bcaa --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +import java.util.Map; + +/** + * Used with {@link TransformProcess TransformProcess} to filter-out an observation. + * + * @author Alexandre Boulanger + */ +public interface FilterOperation { + /** + * The logic that determines if the observation should be skipped. + * + * @param channelsData the name of the channel + * @param currentObservationStep The step number if the observation in the current episode. + * @param isFinalObservation true if this is the last observation of the episode + * @return true if the observation should be skipped + */ + boolean isSkipped(Map channelsData, int currentObservationStep, boolean isFinalObservation); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java similarity index 62% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java index 1d8363ad8..a17bdc6c4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java @@ -1,32 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * ObservationPool is used with the PoolingDataSetPreProcessor. Used to supervise how data from the - * PoolingDataSetPreProcessor is stored. - * - * @author Alexandre Boulanger - */ -public interface ObservationPool { - void add(INDArray observation); - INDArray[] get(); - boolean isAtFullCapacity(); - void reset(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +/** + * The {@link TransformProcess TransformProcess} will call reset() (at the start of an episode) of any step that implement this interface. + */ +public interface ResettableOperation { + /** + * Called by TransformProcess when an episode starts. See {@link TransformProcess#reset() TransformProcess.reset()} + */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java new file mode 100644 index 000000000..0b31752d4 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.filter; + +import org.deeplearning4j.rl4j.observation.transform.FilterOperation; +import org.nd4j.base.Preconditions; +import java.util.Map; + +/** + * Used with {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}. Will cause the + * transform process to skip a fixed number of frames between non skipped ones. + * + * @author Alexandre Boulanger + */ +public class UniformSkippingFilter implements FilterOperation { + + private final int skipFrame; + + /** + * @param skipFrame Will cause the filter to keep (not skip) 1 frame every skipFrames. + */ + public UniformSkippingFilter(int skipFrame) { + Preconditions.checkArgument(skipFrame > 0, "skipFrame should be greater than 0"); + + this.skipFrame = skipFrame; + } + + @Override + public boolean isSkipped(Map channelsData, int currentObservationStep, boolean isFinalObservation) { + return !isFinalObservation && (currentObservationStep % skipFrame != 0); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java new file mode 100644 index 000000000..a9214bbff --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.bytedeco.javacv.OpenCVFrameConverter; +import org.bytedeco.opencv.opencv_core.Mat; +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.bytedeco.opencv.global.opencv_core.CV_32FC; + +public class EncodableToINDArrayTransform implements Operation { + + private final int[] shape; + + public EncodableToINDArrayTransform(int[] shape) { + this.shape = shape; + } + + @Override + public INDArray transform(Encodable encodable) { + return Nd4j.create(encodable.toArray()).reshape(shape); + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java new file mode 100644 index 000000000..b1c32abed --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWriteableTransform.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.bytedeco.javacv.OpenCVFrameConverter; +import org.bytedeco.opencv.opencv_core.Mat; +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.bytedeco.opencv.global.opencv_core.CV_32FC; + +public class EncodableToImageWriteableTransform implements Operation { + + private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); + private final int height; + private final int width; + private final int colorChannels; + + public EncodableToImageWriteableTransform(int height, int width, int colorChannels) { + this.height = height; + this.width = width; + this.colorChannels = colorChannels; + } + + @Override + public ImageWritable transform(Encodable encodable) { + INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels); + Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer()); + return new ImageWritable(converter.convert(mat)); + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java new file mode 100644 index 000000000..d20f9f9f8 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWriteableToINDArrayTransform.java @@ -0,0 +1,37 @@ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.datavec.image.loader.NativeImageLoader; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.IOException; + +public class ImageWriteableToINDArrayTransform implements Operation { + + private final int height; + private final int width; + private final NativeImageLoader loader; + + public ImageWriteableToINDArrayTransform(int height, int width) { + this.height = height; + this.width = width; + this.loader = new NativeImageLoader(height, width); + } + + @Override + public INDArray transform(ImageWritable imageWritable) { + INDArray out = null; + try { + out = loader.asMatrix(imageWritable); + } catch (IOException e) { + e.printStackTrace(); + } + out = out.reshape(1, height, width); + INDArray compressed = out.castTo(DataType.UINT8); + return compressed; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java new file mode 100644 index 000000000..e27d1134c --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java @@ -0,0 +1,147 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.datavec.api.transform.Operation; +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.deeplearning4j.rl4j.observation.transform.ResettableOperation; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryStackAssembler; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * The HistoryMergeTransform will accumulate features from incoming INDArrays and will assemble its content + * into a new INDArray containing a single example. + * + * This is used in scenarios where motion in an important element. + * + * There is a special case: + * * When the store is not full (not ready), the data from the incoming INDArray is stored but null is returned (will be interpreted as a skipped observation) + *
+ * The HistoryMergeTransform requires two sub components:
+ * 1) The {@link HistoryMergeElementStore HistoryMergeElementStore} that supervises what and how input INDArrays are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...) + * The default is a Circular FIFO. + * 2) The {@link HistoryMergeAssembler HistoryMergeAssembler} that will assemble the store content into a resulting single INDArray. (ex.: stacked along a dimension, squashed into a single observation, etc...) + * The default is stacking along the dimension 0. + * + * @author Alexandre Boulanger + */ +public class HistoryMergeTransform implements Operation, ResettableOperation { + + private final HistoryMergeElementStore historyMergeElementStore; + private final HistoryMergeAssembler historyMergeAssembler; + private final boolean shouldStoreCopy; + private final boolean isFirstDimenstionBatch; + + private HistoryMergeTransform(Builder builder) { + this.historyMergeElementStore = builder.historyMergeElementStore; + this.historyMergeAssembler = builder.historyMergeAssembler; + this.shouldStoreCopy = builder.shouldStoreCopy; + this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch; + } + + @Override + public INDArray transform(INDArray input) { + INDArray element; + if(isFirstDimenstionBatch) { + element = input.slice(0, 0); + } + else { + element = input; + } + + if(shouldStoreCopy) { + element = element.dup(); + } + + historyMergeElementStore.add(element); + if(!historyMergeElementStore.isReady()) { + return null; + } + + INDArray result = historyMergeAssembler.assemble(historyMergeElementStore.get()); + + return INDArrayHelper.forceCorrectShape(result); + } + + @Override + public void reset() { + historyMergeElementStore.reset(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private HistoryMergeElementStore historyMergeElementStore; + private HistoryMergeAssembler historyMergeAssembler; + private boolean shouldStoreCopy = false; + private boolean isFirstDimenstionBatch = false; + + /** + * Default is {@link CircularFifoStore CircularFifoStore} + */ + public Builder elementStore(HistoryMergeElementStore store) { + this.historyMergeElementStore = store; + return this; + } + + /** + * Default is {@link HistoryStackAssembler HistoryStackAssembler} + */ + public Builder assembler(HistoryMergeAssembler assembler) { + this.historyMergeAssembler = assembler; + return this; + } + + /** + * If true, tell the HistoryMergeTransform to store copies of incoming INDArrays. + * (To prevent later in-place changes to a stored INDArray from changing what has been stored) + * + * Default is false + */ + public Builder shouldStoreCopy(boolean shouldStoreCopy) { + this.shouldStoreCopy = shouldStoreCopy; + return this; + } + + /** + * If true, tell the HistoryMergeTransform that the first dimension of the input INDArray is the batch count. + * When this is the case, the HistoryMergeTransform will slice the input like this [batch, height, width] -> [height, width] + * + * Default is false + */ + public Builder isFirstDimenstionBatch(boolean isFirstDimenstionBatch) { + this.isFirstDimenstionBatch = isFirstDimenstionBatch; + return this; + } + + public HistoryMergeTransform build() { + if(historyMergeElementStore == null) { + historyMergeElementStore = new CircularFifoStore(); + } + + if(historyMergeAssembler == null) { + historyMergeAssembler = new HistoryStackAssembler(); + } + + return new HistoryMergeTransform(this); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java similarity index 52% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java index 6eb950e48..db1cbb2bd 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java @@ -1,95 +1,82 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.apache.commons.collections4.queue.CircularFifoQueue; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue - * with a fixed size that replaces its oldest element if full. - * - * @author Alexandre Boulanger - */ -public class CircularFifoObservationPool implements ObservationPool { - private static final int DEFAULT_POOL_SIZE = 4; - - private final CircularFifoQueue queue; - - private CircularFifoObservationPool(Builder builder) { - queue = new CircularFifoQueue<>(builder.poolSize); - } - - public CircularFifoObservationPool() - { - this(DEFAULT_POOL_SIZE); - } - - public CircularFifoObservationPool(int poolSize) - { - Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize); - queue = new CircularFifoQueue<>(poolSize); - } - - /** - * Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one. - * @param elem - */ - public void add(INDArray elem) { - queue.add(elem); - } - - /** - * @return The content of the pool, returned in order from oldest to newest. - */ - public INDArray[] get() { - int size = queue.size(); - INDArray[] array = new INDArray[size]; - for (int i = 0; i < size; ++i) { - array[i] = queue.get(i).castTo(Nd4j.dataType()); - } - return array; - } - - public boolean isAtFullCapacity() { - return queue.isAtFullCapacity(); - } - - @Override - public void reset() { - queue.clear(); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private int poolSize = DEFAULT_POOL_SIZE; - - public Builder poolSize(int poolSize) { - this.poolSize = poolSize; - return this; - } - - public CircularFifoObservationPool build() { - return new CircularFifoObservationPool(this); - } - } -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.apache.commons.collections4.queue.CircularFifoQueue; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * CircularFifoStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This store is a first-in first-out queue + * with a fixed size that replaces its oldest element if full. + * + * @author Alexandre Boulanger + */ +public class CircularFifoStore implements HistoryMergeElementStore { + private static final int DEFAULT_STORE_SIZE = 4; + + private final CircularFifoQueue queue; + + public CircularFifoStore() { + this(DEFAULT_STORE_SIZE); + } + + public CircularFifoStore(int size) { + Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); + queue = new CircularFifoQueue<>(size); + } + + /** + * Add an element to the store, if this addition would make the store to overflow, the new element replaces the oldest. + * @param elem + */ + @Override + public void add(INDArray elem) { + queue.add(elem); + } + + /** + * @return The content of the store, returned in order from oldest to newest. + */ + @Override + public INDArray[] get() { + int size = queue.size(); + INDArray[] array = new INDArray[size]; + for (int i = 0; i < size; ++i) { + array[i] = queue.get(i).castTo(Nd4j.dataType()); + } + return array; + } + + /** + * The CircularFifoStore needs to be completely filled before being ready. + * @return false when the number of elements in the store is less than the store capacity (default is 4) + */ + @Override + public boolean isReady() { + return queue.isAtFullCapacity(); + } + + /** + * Clears the store. + */ + @Override + public void reset() { + queue.clear(); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java new file mode 100644 index 000000000..0487d7c57 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A HistoryMergeAssembler is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This interface defines how the array of INDArray + * given by the {@link HistoryMergeElementStore HistoryMergeElementStore} is packaged into the single INDArray that will be + * returned by the HistoryMergeTransform + * + * @author Alexandre Boulanger + */ +public interface HistoryMergeAssembler { + /** + * Assemble an array of INDArray into a single INArray + * @param elements The input INDArray[] + * @return the assembled INDArray + */ + INDArray assemble(INDArray[] elements); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java new file mode 100644 index 000000000..04d61da45 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * HistoryMergeElementStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. Used to supervise how data from the + * HistoryMergeTransform is stored. + * + * @author Alexandre Boulanger + */ +public interface HistoryMergeElementStore { + /** + * Add an element into the store + * @param observation + */ + void add(INDArray observation); + + /** + * Get the content of the store + * @return the content of the store + */ + INDArray[] get(); + + /** + * Used to tell the HistoryMergeTransform that the store is ready. The HistoryMergeTransform will tell the {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess} + * to skip the observation is the store is not ready. + * @return true if the store is ready + */ + boolean isReady(); + + /** + * Resets the store to an initial state. + */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java similarity index 62% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java index d53f5c8c8..0559f25df 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java @@ -1,53 +1,52 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will - * stack along the dimension 0. For example if the pool elements are of shape [ Height, Width ] - * the output will be of shape [ Stacked, Height, Width ] - * - * @author Alexandre Boulanger - */ -public class ChannelStackPoolContentAssembler implements PoolContentAssembler { - - /** - * Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0. - * - * @param poolContent Array of INDArray - * @return A new INDArray with 1 more dimension than the input elements - */ - @Override - public INDArray assemble(INDArray[] poolContent) - { - // build the new shape - long[] elementShape = poolContent[0].shape(); - long[] newShape = new long[elementShape.length + 1]; - newShape[0] = poolContent.length; - System.arraycopy(elementShape, 0, newShape, 1, elementShape.length); - - // put pool elements in result - INDArray result = Nd4j.create(newShape); - for(int i = 0; i < poolContent.length; ++i) { - result.putRow(i, poolContent[i]); - } - return result; - } -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * HistoryStackAssembler is used with the HistoryMergeTransform. This assembler will + * stack along the dimension 0. For example if the store elements are of shape [ Height, Width ] + * the output will be of shape [ Stacked, Height, Width ] + * + * @author Alexandre Boulanger + */ +public class HistoryStackAssembler implements HistoryMergeAssembler { + + /** + * Will return a new INDArray with one more dimension and with elements stacked along dimension 0. + * + * @param elements Array of INDArray + * @return A new INDArray with 1 more dimension than the input elements + */ + @Override + public INDArray assemble(INDArray[] elements) { + // build the new shape + long[] elementShape = elements[0].shape(); + long[] newShape = new long[elementShape.length + 1]; + newShape[0] = elements.length; + System.arraycopy(elementShape, 0, newShape, 1, elementShape.length); + + // stack the elements in result on the dimension 0 + INDArray result = Nd4j.create(newShape); + for(int i = 0; i < elements.length; ++i) { + result.putRow(i, elements[i]); + } + return result; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index fb20a60ac..7719df612 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -89,7 +89,7 @@ public abstract class Policy implements IPolicy { getNeuralNet().reset(); } - private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { epochStepCounter.setCurrentEpochStep(0); double reward = 0; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java new file mode 100644 index 000000000..9bfceadad --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -0,0 +1,38 @@ +package org.deeplearning4j.rl4j.helper; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class INDArrayHelperTest { + @Test + public void when_inputHasIncorrectShape_expect_outputWithCorrectShape() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(3, output.shape()[1]); + } + + @Test + public void when_inputHasCorrectShape_expect_outputWithSameShape() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}).reshape(1, 3); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(3, output.shape()[1]); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java deleted file mode 100644 index db239b0f3..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java +++ /dev/null @@ -1,164 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool; -import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler; -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; - -public class PoolingDataSetPreProcessorTest { - - @Test(expected = NullPointerException.class) - public void when_dataSetIsNull_expect_NullPointerException() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - - // Act - sut.preProcess(null); - } - - @Test(expected = IllegalArgumentException.class) - public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - - // Act - sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null)); - } - - @Test - public void when_dataSetIsEmpty_expect_EmptyDataSet() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - DataSet ds = new DataSet(null, null); - - // Act - sut.preProcess(ds); - - // Assert - Assert.assertTrue(ds.isEmpty()); - } - - @Test - public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() { - // Arrange - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - DataSet[] observations = new DataSet[5]; - INDArray[] inputs = new INDArray[5]; - - - // Act - for(int i = 0; i < 5; ++i) { - inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 }); - DataSet input = new DataSet(inputs[i], null); - sut.preProcess(input); - observations[i] = input; - } - - // Assert - assertTrue(observations[0].isEmpty()); - assertTrue(observations[1].isEmpty()); - assertTrue(observations[2].isEmpty()); - - for(int i = 0; i < 4; ++i) { - assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); - } - - for(int i = 0; i < 4; ++i) { - assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); - } - - } - - @Test - public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() { - // Arrange - INDArray input = Nd4j.rand(1, 1); - TestObservationPool pool = new TestObservationPool(); - TestPoolContentAssembler assembler = new TestPoolContentAssembler(); - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder() - .observablePool(pool) - .poolContentAssembler(assembler) - .build(); - - // Act - sut.preProcess(new DataSet(input, null)); - - // Assert - assertTrue(pool.isAtFullCapacityCalled); - assertTrue(pool.isGetCalled); - assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0); - assertTrue(assembler.assembleIsCalled); - } - - @Test - public void when_pastInputChanges_expect_outputNotChanged() { - // Arrange - INDArray input = Nd4j.zeros(1, 1); - TestObservationPool pool = new TestObservationPool(); - TestPoolContentAssembler assembler = new TestPoolContentAssembler(); - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder() - .observablePool(pool) - .poolContentAssembler(assembler) - .build(); - - // Act - sut.preProcess(new DataSet(input, null)); - input.putScalar(0, 0, 1.0); - - // Assert - assertEquals(0.0, pool.observation.getDouble(0), 0.0); - } - - private static class TestObservationPool implements ObservationPool { - - public INDArray observation; - public boolean isGetCalled; - public boolean isAtFullCapacityCalled; - private boolean isResetCalled; - - @Override - public void add(INDArray observation) { - this.observation = observation; - } - - @Override - public INDArray[] get() { - isGetCalled = true; - return new INDArray[0]; - } - - @Override - public boolean isAtFullCapacity() { - isAtFullCapacityCalled = true; - return true; - } - - @Override - public void reset() { - isResetCalled = true; - } - } - - private static class TestPoolContentAssembler implements PoolContentAssembler { - - public boolean assembleIsCalled; - - @Override - public INDArray assemble(INDArray[] poolContent) { - assembleIsCalled = true; - return Nd4j.create(1, 1); - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java deleted file mode 100644 index 3f1de3426..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java +++ /dev/null @@ -1,70 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class SkippingDataSetPreProcessorTest { - @Test(expected = IllegalArgumentException.class) - public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() { - SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0); - } - - @Test(expected = IllegalArgumentException.class) - public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() { - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(0) - .build(); - } - - @Test - public void when_skipFrameIs3_expect_Skip2OutOf3() { - // Arrange - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(3) - .build(); - DataSet[] results = new DataSet[4]; - - // Act - for(int i = 0; i < 4; ++i) { - results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - sut.preProcess(results[i]); - } - - // Assert - assertFalse(results[0].isEmpty()); - assertTrue(results[1].isEmpty()); - assertTrue(results[2].isEmpty()); - assertFalse(results[3].isEmpty()); - } - - @Test - public void when_resetIsCalled_expect_skippingIsReset() { - // Arrange - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(3) - .build(); - DataSet[] results = new DataSet[4]; - - // Act - results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - - sut.preProcess(results[0]); - sut.preProcess(results[1]); - sut.reset(); - sut.preProcess(results[2]); - sut.preProcess(results[3]); - - // Assert - assertFalse(results[0].isEmpty()); - assertTrue(results[1].isEmpty()); - assertFalse(results[2].isEmpty()); - assertTrue(results[3].isEmpty()); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java deleted file mode 100644 index de0db015c..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java +++ /dev/null @@ -1,41 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class ChannelStackPoolContentAssemblerTest { - - @Test - public void when_assemble_expect_poolContentStackedOnChannel() { - // Assemble - ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler(); - INDArray[] poolContent = new INDArray[] { - Nd4j.rand(2, 2), - Nd4j.rand(2, 2), - }; - - // Act - INDArray result = sut.assemble(poolContent); - - // Assert - assertEquals(3, result.shape().length); - assertEquals(2, result.shape()[0]); - assertEquals(2, result.shape()[1]); - assertEquals(2, result.shape()[2]); - - assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001); - assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001); - assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001); - assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001); - - assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001); - assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001); - assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001); - assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001); - - } - -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java deleted file mode 100644 index 88e7b33dd..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java +++ /dev/null @@ -1,100 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class CircularFifoObservationPoolTest { - - @Test(expected = IllegalArgumentException.class) - public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() { - CircularFifoObservationPool sut = new CircularFifoObservationPool(0); - } - - @Test - public void when_poolIsEmpty_expect_NotReady() { - // Assemble - CircularFifoObservationPool sut = new CircularFifoObservationPool(); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertFalse(isReady); - } - - @Test - public void when_notEnoughElementsInPool_expect_notReady() { - // Assemble - CircularFifoObservationPool sut = new CircularFifoObservationPool(); - sut.add(Nd4j.create(new double[] { 123.0 })); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertFalse(isReady); - } - - @Test - public void when_enoughElementsInPool_expect_ready() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder() - .poolSize(2) - .build(); - sut.add(Nd4j.createFromArray(123.0)); - sut.add(Nd4j.createFromArray(123.0)); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertTrue(isReady); - } - - @Test - public void when_addMoreThanSize_expect_getReturnOnlyLastElements() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); - sut.add(Nd4j.createFromArray(0.0)); - sut.add(Nd4j.createFromArray(1.0)); - sut.add(Nd4j.createFromArray(2.0)); - sut.add(Nd4j.createFromArray(3.0)); - sut.add(Nd4j.createFromArray(4.0)); - sut.add(Nd4j.createFromArray(5.0)); - sut.add(Nd4j.createFromArray(6.0)); - - // Act - INDArray[] result = sut.get(); - - // Assert - assertEquals(3.0, result[0].getDouble(0), 0.0); - assertEquals(4.0, result[1].getDouble(0), 0.0); - assertEquals(5.0, result[2].getDouble(0), 0.0); - assertEquals(6.0, result[3].getDouble(0), 0.0); - } - - @Test - public void when_resetIsCalled_expect_poolContentFlushed() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); - sut.add(Nd4j.createFromArray(0.0)); - sut.add(Nd4j.createFromArray(1.0)); - sut.add(Nd4j.createFromArray(2.0)); - sut.add(Nd4j.createFromArray(3.0)); - sut.add(Nd4j.createFromArray(4.0)); - sut.add(Nd4j.createFromArray(5.0)); - sut.add(Nd4j.createFromArray(6.0)); - sut.reset(); - - // Act - INDArray[] result = sut.get(); - - // Assert - assertEquals(0, result.length); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java new file mode 100644 index 000000000..9c7a172bb --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java @@ -0,0 +1,166 @@ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class HistoryMergeTransformTest { + + @Test + public void when_firstDimensionIsNotBatch_expect_observationAddedAsIs() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(false) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + sut.transform(input); + + // Assert + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_firstDimensionIsBatch_expect_observationAddedAsSliced() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3); + + // Act + sut.transform(input); + + // Assert + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_notReady_expect_resultIsNull() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertNull(result); + } + + @Test + public void when_notShouldStoreCopy_expect_sameIsStored() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .shouldStoreCopy(false) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertSame(input, store.addedObservation); + } + + @Test + public void when_shouldStoreCopy_expect_copyIsStored() { + // Arrange + MockStore store = new MockStore(true); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .shouldStoreCopy(true) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertNotSame(input, store.addedObservation); + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_transformCalled_expect_storeContentAssembledAndOutputHasCorrectShape() { + // Arrange + MockStore store = new MockStore(true); + MockAssemble assemble = new MockAssemble(); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .elementStore(store) + .assembler(assemble) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertEquals(1, assemble.assembleElements.length); + assertSame(store.addedObservation, assemble.assembleElements[0]); + + assertEquals(2, result.shape().length); + assertEquals(1, result.shape()[0]); + assertEquals(3, result.shape()[1]); + } + + public static class MockStore implements HistoryMergeElementStore { + + private final boolean isReady; + private INDArray addedObservation; + + public MockStore(boolean isReady) { + + this.isReady = isReady; + } + + @Override + public void add(INDArray observation) { + addedObservation = observation; + } + + @Override + public INDArray[] get() { + return new INDArray[] { addedObservation }; + } + + @Override + public boolean isReady() { + return isReady; + } + + @Override + public void reset() { + + } + } + + public static class MockAssemble implements HistoryMergeAssembler { + + private INDArray[] assembleElements; + + @Override + public INDArray assemble(INDArray[] elements) { + assembleElements = elements; + return elements[0]; + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java new file mode 100644 index 000000000..f9b34a1f1 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java @@ -0,0 +1,77 @@ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class CircularFifoStoreTest { + + @Test(expected = IllegalArgumentException.class) + public void when_fifoSizeIsLessThan1_expect_exception() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(0); + } + + @Test + public void when_adding2elementsWithSize2_expect_notReadyAfter1stReadyAfter2nd() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + boolean isReadyAfter1st = sut.isReady(); + sut.add(secondElement); + boolean isReadyAfter2nd = sut.isReady(); + + // Assert + assertFalse(isReadyAfter1st); + assertTrue(isReadyAfter2nd); + } + + @Test + public void when_adding2elementsWithSize2_expect_getReturnThese2() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + sut.add(secondElement); + INDArray[] results = sut.get(); + + // Assert + assertEquals(2, results.length); + + assertEquals(1.0, results[0].getDouble(0), 0.00001); + assertEquals(2.0, results[0].getDouble(1), 0.00001); + assertEquals(3.0, results[0].getDouble(2), 0.00001); + + assertEquals(10.0, results[1].getDouble(0), 0.00001); + assertEquals(20.0, results[1].getDouble(1), 0.00001); + assertEquals(30.0, results[1].getDouble(2), 0.00001); + + } + + @Test + public void when_adding2elementsThenCallingReset_expect_getReturnEmpty() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + sut.add(secondElement); + sut.reset(); + INDArray[] results = sut.get(); + + // Assert + assertEquals(0, results.length); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java new file mode 100644 index 000000000..36826430e --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java @@ -0,0 +1,37 @@ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class HistoryStackAssemblerTest { + + @Test + public void when_assembling2INDArrays_expect_stackedAsResult() { + // Arrange + INDArray[] input = new INDArray[] { + Nd4j.create(new double[] { 1.0, 2.0, 3.0 }), + Nd4j.create(new double[] { 10.0, 20.0, 30.0 }), + }; + HistoryStackAssembler sut = new HistoryStackAssembler(); + + // Act + INDArray result = sut.assemble(input); + + // Assert + assertEquals(2, result.shape().length); + assertEquals(2, result.shape()[0]); + assertEquals(3, result.shape()[1]); + + assertEquals(1.0, result.getDouble(0, 0), 0.00001); + assertEquals(2.0, result.getDouble(0, 1), 0.00001); + assertEquals(3.0, result.getDouble(0, 2), 0.00001); + + assertEquals(10.0, result.getDouble(1, 0), 0.00001); + assertEquals(20.0, result.getDouble(1, 1), 0.00001); + assertEquals(30.0, result.getDouble(1, 2), 0.00001); + + } +}