diff --git a/.gitignore b/.gitignore index 501c270c4..ad2e28e6f 100644 --- a/.gitignore +++ b/.gitignore @@ -65,4 +65,8 @@ doc_sources_* # Python virtual environments venv/ -venv2/ \ No newline at end of file +venv2/ + +# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts +nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java index d8caf1962..ae2d17788 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/ColumnOp.java @@ -26,10 +26,7 @@ import org.datavec.api.transform.schema.Schema; * * @author Adam Gibson */ -public interface ColumnOp { - /** Get the output schema for this transformation, given an input schema */ - Schema transform(Schema inputSchema); - +public interface ColumnOp extends Operation { /** Set the input schema. */ diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java similarity index 63% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java rename to datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java index 46ff4e39c..ecb58543b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/ResettableDataSetPreProcessor.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/Operation.java @@ -1,28 +1,20 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.nd4j.linalg.dataset.api.DataSetPreProcessor; - -/** - * A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games). - * - * @author Alexandre Boulanger - */ -public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor { - public abstract void reset(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.transform; + +public interface Operation<TIn, TOut> { + TOut transform(TIn input); +} diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java index afcdf894f..bdd74df13 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransform.java @@ -16,6 +16,7 @@ package org.datavec.image.transform; +import org.datavec.api.transform.Operation; import org.datavec.image.data.ImageWritable; import org.nd4j.shade.jackson.annotation.JsonInclude; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; @@ -29,15 +30,7 @@ import java.util.Random; */ @JsonInclude(JsonInclude.Include.NON_NULL) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ImageTransform { - - /** - * Takes an image and returns a transformed image. - * - * @param image to transform, null == end of stream - * @return transformed image - */ - ImageWritable transform(ImageWritable image); +public interface ImageTransform extends Operation<ImageWritable, ImageWritable> { /** * Takes an image and returns a transformed image. 
diff --git a/pom.xml b/pom.xml index d4bf988af..a2188de1a 100644 --- a/pom.xml +++ b/pom.xml @@ -297,14 +297,14 @@ 1.18.1 ${numpy.version}-${javacpp-presets.version} - 0.3.8 + 0.3.9 2020.0 4.2.0 4.2.2 1.79.0 - 1.10.6 + 1.12.0 0.6.1 - 0.15.5 + 0.17.1 1.15.2 ${tensorflow.version}-${javacpp-presets.version} diff --git a/rl4j/rl4j-core/pom.xml b/rl4j/rl4j-core/pom.xml index a78157603..c08615250 100644 --- a/rl4j/rl4j-core/pom.xml +++ b/rl4j/rl4j-core/pom.xml @@ -102,6 +102,13 @@ gson ${gson.version} + + + org.datavec + datavec-api + ${datavec.version} + + diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java similarity index 52% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 63b382a09..7d93b1175 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/PoolContentAssembler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -1,30 +1,39 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray - * returned by the ObservationPool is packaged into the single INDArray that will be set - * in the DataSet of PoolingDataSetPreProcessor.preProcess - * - * @author Alexandre Boulanger - */ -public interface PoolContentAssembler { - INDArray assemble(INDArray[] poolContent); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.helper; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * INDArray helper methods used by RL4J + * + * @author Alexandre Boulanger + */ +public class INDArrayHelper { + /** + * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray. + * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape. 
+ * + * @param source A INDArray + * @return The source INDArray with the correct shape + */ + public static INDArray forceCorrectShape(INDArray source) { + return source.shape()[0] == 1 + ? source + : Nd4j.expandDims(source, 0); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index 0b3eb1c72..27d49c366 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -90,7 +90,7 @@ public abstract class AsyncThreadDiscrete accuReward += stepReply.getReward() * getConf().getRewardFactor(); //if it's not a skipped frame, you can do a step of training - if (!obs.isSkipped() || stepReply.isDone()) { + if (!obs.isSkipped()) { INDArray[] output = current.outputAll(obs.getData()); rewards.add(new MiniTrans(obs.getData(), action, output, accuReward)); @@ -99,7 +99,6 @@ public abstract class AsyncThreadDiscrete } obs = stepReply.getObservation(); - reward += stepReply.getReward(); incrementStep(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java index 0fa0f33e6..614ddb793 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java @@ -158,7 +158,7 @@ public abstract class QLearningDiscrete extends QLearning - * The PoolingDataSetPreProcessor requires two sub components:
- * 1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...) - * The default is a Circular FIFO. - * 2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...) - * The default is stacking along the dimension 0. - * - * @author Alexandre Boulanger - */ -public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor { - private final ObservationPool observationPool; - private final PoolContentAssembler poolContentAssembler; - - protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder) - { - observationPool = builder.observationPool; - poolContentAssembler = builder.poolContentAssembler; - } - - /** - * Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the - * Policy/Learning class and other DataSetPreProcessors - * - * @param dataSet - */ - @Override - public void preProcess(DataSet dataSet) { - Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); - - if(dataSet.isEmpty()) { - return; - } - - Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported"); - - // store a duplicate in the pool - observationPool.add(dataSet.getFeatures().slice(0, 0).dup()); - if(!observationPool.isAtFullCapacity()) { - dataSet.setFeatures(null); - return; - } - - INDArray result = poolContentAssembler.assemble(observationPool.get()); - - // return a DataSet containing only 1 example (the result) - long[] resultShape = result.shape(); - long[] newShape = new long[resultShape.length + 1]; - newShape[0] = 1; - System.arraycopy(resultShape, 0, newShape, 1, resultShape.length); - - dataSet.setFeatures(result.reshape(newShape)); - } - - public static PoolingDataSetPreProcessor.Builder builder() { - return new 
PoolingDataSetPreProcessor.Builder(); - } - - @Override - public void reset() { - observationPool.reset(); - } - - public static class Builder { - private ObservationPool observationPool; - private PoolContentAssembler poolContentAssembler; - - /** - * Default is CircularFifoObservationPool - */ - public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) { - this.observationPool = observationPool; - return this; - } - - /** - * Default is ChannelStackPoolContentAssembler - */ - public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) { - this.poolContentAssembler = poolContentAssembler; - return this; - } - - public PoolingDataSetPreProcessor build() { - if(observationPool == null) { - observationPool = new CircularFifoObservationPool(); - } - - if(poolContentAssembler == null) { - poolContentAssembler = new ChannelStackPoolContentAssembler(); - } - - return new PoolingDataSetPreProcessor(this); - } - } - -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java deleted file mode 100644 index 940966823..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessor.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor; - -import lombok.Builder; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.dataset.api.DataSet; - -/** - * The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty - * the input DataSet when skipping. - * - * @author Alexandre Boulanger - */ -public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor { - - private final int skipFrame; - - private int currentIdx = 0; - - /** - * @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations. - */ - @Builder - public SkippingDataSetPreProcessor(int skipFrame) { - Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame); - this.skipFrame = skipFrame; - } - - @Override - public void preProcess(DataSet dataSet) { - Preconditions.checkNotNull(dataSet, "Encountered null dataSet"); - - if(dataSet.isEmpty()) { - return; - } - - if(currentIdx++ % skipFrame != 0) { - dataSet.setFeatures(null); - dataSet.setLabels(null); - } - } - - @Override - public void reset() { - currentIdx = 0; - } -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java new file mode 100644 index 000000000..74b67bcaa --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +import java.util.Map; + +/** + * Used with {@link TransformProcess TransformProcess} to filter-out an observation. + * + * @author Alexandre Boulanger + */ +public interface FilterOperation { + /** + * The logic that determines if the observation should be skipped. + * + * @param channelsData the channels data, mapped by channel name + * @param currentObservationStep The step number of the observation in the current episode. 
+ * @param isFinalObservation true if this is the last observation of the episode + * @return true if the observation should be skipped + */ + boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java similarity index 62% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java index 1d8363ad8..a17bdc6c4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ObservationPool.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java @@ -1,32 +1,26 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * ObservationPool is used with the PoolingDataSetPreProcessor. 
Used to supervise how data from the - * PoolingDataSetPreProcessor is stored. - * - * @author Alexandre Boulanger - */ -public interface ObservationPool { - void add(INDArray observation); - INDArray[] get(); - boolean isAtFullCapacity(); - void reset(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +/** + * The {@link TransformProcess TransformProcess} will call reset() (at the start of an episode) of any step that implement this interface. + */ +public interface ResettableOperation { + /** + * Called by TransformProcess when an episode starts. See {@link TransformProcess#reset() TransformProcess.reset()} + */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java new file mode 100644 index 000000000..53bc01421 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java @@ -0,0 +1,231 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform; + +import org.apache.commons.lang3.NotImplementedException; +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.shade.guava.collect.Maps; +import org.datavec.api.transform.Operation; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +/** + * A TransformProcess will build an {@link Observation Observation} from the raw data coming from the environment. + * Three types of steps are available: + * 1) A {@link FilterOperation FilterOperation}: Used to determine if an observation should be skipped. + * 2) An {@link Operation Operation}: Applies a transform and/or conversion to an observation channel. + * 3) A {@link DataSetPreProcessor DataSetPreProcessor}: Applies a DataSetPreProcessor to the observation channel. The data in the channel must be a DataSet. + * + * Instances of the three types above can be called in any order. 
The only requirement is that when build() is called, + * all channels must be instances of INDArrays or DataSets + * + * NOTE: Presently, only single-channels observations are supported. + * + * @author Alexandre Boulanger + */ +public class TransformProcess { + + private final List> operations; + private final String[] channelNames; + private final HashSet operationsChannelNames; + + private TransformProcess(Builder builder, String... channelNames) { + operations = builder.operations; + this.channelNames = channelNames; + this.operationsChannelNames = builder.requiredChannelNames; + } + + /** + * This method will call reset() of all steps implementing {@link ResettableOperation ResettableOperation} in the transform process. + */ + public void reset() { + for(Map.Entry entry : operations) { + if(entry.getValue() instanceof ResettableOperation) { + ((ResettableOperation) entry.getValue()).reset(); + } + } + } + + /** + * Transforms the channel data into an Observation or a skipped observation depending on the specific steps in the transform process. + * + * @param channelsData A Map that maps the channel name to its data. + * @param currentObservationStep The observation's step number within the episode. + * @param isFinalObservation True if the observation is the last of the episode. 
+ * @return An observation (may be a skipped observation) + */ + public Observation transform(Map channelsData, int currentObservationStep, boolean isFinalObservation) { + // null or empty channelData + Preconditions.checkArgument(channelsData != null && channelsData.size() != 0, "Error: channelsData not supplied."); + + // Check that all channels have data + for(Map.Entry channel : channelsData.entrySet()) { + Preconditions.checkNotNull(channel.getValue(), "Error: data of channel '%s' is null", channel.getKey()); + } + + // Check that all required channels are present + for(String channelName : operationsChannelNames) { + Preconditions.checkArgument(channelsData.containsKey(channelName), "The channelsData map does not contain the channel '%s'", channelName); + } + + for(Map.Entry entry : operations) { + + // Filter + if(entry.getValue() instanceof FilterOperation) { + FilterOperation filterOperation = (FilterOperation)entry.getValue(); + if(filterOperation.isSkipped(channelsData, currentObservationStep, isFinalObservation)) { + return Observation.SkippedObservation; + } + } + + // Transform + // null results are considered skipped observations + else if(entry.getValue() instanceof Operation) { + Operation transformOperation = (Operation)entry.getValue(); + Object transformed = transformOperation.transform(channelsData.get(entry.getKey())); + if(transformed == null) { + return Observation.SkippedObservation; + } + channelsData.replace(entry.getKey(), transformed); + } + + // PreProcess + else if(entry.getValue() instanceof DataSetPreProcessor) { + Object channelData = channelsData.get(entry.getKey()); + DataSetPreProcessor dataSetPreProcessor = (DataSetPreProcessor)entry.getValue(); + if(!(channelData instanceof DataSet)) { + throw new IllegalArgumentException("The channel data must be a DataSet to call preProcess"); + } + dataSetPreProcessor.preProcess((DataSet)channelData); + } + + else { + throw new IllegalArgumentException(String.format("Unknown operation: 
'%s'", entry.getValue().getClass().getName())); + } + } + + // Check that all channels used to build the observation are instances of + // INDArray or DataSet + // TODO: Add support for an interface with a toINDArray() method + for(String channelName : channelNames) { + Object channelData = channelsData.get(channelName); + + INDArray finalChannelData; + if(channelData instanceof DataSet) { + finalChannelData = ((DataSet)channelData).getFeatures(); + } + else if(channelData instanceof INDArray) { + finalChannelData = (INDArray) channelData; + } + else { + throw new IllegalStateException("All channels used to build the observation must be instances of DataSet or INDArray"); + } + + // The dimension 0 of all INDArrays must be 1 (batch count) + channelsData.replace(channelName, INDArrayHelper.forceCorrectShape(finalChannelData)); + } + + // TODO: Add support to multi-channel observations + INDArray data = ((INDArray) channelsData.get(channelNames[0])); + return new Observation(data); + } + + /** + * @return An instance of a builder + */ + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private final List> operations = new ArrayList>(); + private final HashSet requiredChannelNames = new HashSet(); + + /** + * Add a filter to the transform process steps. Used to skip observations on certain conditions. + * See {@link FilterOperation FilterOperation} + * @param filterOperation An instance + */ + public Builder filter(FilterOperation filterOperation) { + Preconditions.checkNotNull(filterOperation, "The filterOperation must not be null"); + + operations.add((Map.Entry)Maps.immutableEntry(null, filterOperation)); + return this; + } + + /** + * Add a transform to the steps. The transform can change the data and / or change the type of the data + * (e.g. Byte[] to a ImageWritable) + * + * @param targetChannel The name of the channel to which the transform is applied. 
+ * @param transformOperation An instance of {@link Operation Operation} + */ + public Builder transform(String targetChannel, Operation transformOperation) { + Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null"); + Preconditions.checkNotNull(transformOperation, "The transformOperation must not be null"); + + requiredChannelNames.add(targetChannel); + operations.add((Map.Entry)Maps.immutableEntry(targetChannel, transformOperation)); + return this; + } + + /** + * Add a DataSetPreProcessor to the steps. The channel must be a DataSet instance at this step. + * @param targetChannel The name of the channel to which the pre processor is applied. + * @param dataSetPreProcessor + */ + public Builder preProcess(String targetChannel, DataSetPreProcessor dataSetPreProcessor) { + Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null"); + Preconditions.checkNotNull(dataSetPreProcessor, "The dataSetPreProcessor must not be null"); + + requiredChannelNames.add(targetChannel); + operations.add((Map.Entry)Maps.immutableEntry(targetChannel, dataSetPreProcessor)); + return this; + } + + /** + * Builds the TransformProcess. + * @param channelNames A subset of channel names to be used to build the observation + * @return An instance of TransformProcess + */ + public TransformProcess build(String... 
channelNames) { + if(channelNames.length == 0) { + throw new IllegalArgumentException("At least one channel must be supplied."); + } + + for(String channelName : channelNames) { + Preconditions.checkNotNull(channelName, "Error: got a null channel name"); + requiredChannelNames.add(channelName); + } + + // TODO: Remove when multi-channel observation is supported + if(channelNames.length != 1) { + throw new NotImplementedException("Multi-channel observations is not presently supported."); + } + + return new TransformProcess(this, channelNames); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java new file mode 100644 index 000000000..0b31752d4 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java @@ -0,0 +1,45 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.filter; + +import org.deeplearning4j.rl4j.observation.transform.FilterOperation; +import org.nd4j.base.Preconditions; +import java.util.Map; + +/** + * Used with {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}. Will cause the + * transform process to skip a fixed number of frames between non skipped ones. + * + * @author Alexandre Boulanger + */ +public class UniformSkippingFilter implements FilterOperation { + + private final int skipFrame; + + /** + * @param skipFrame Will cause the filter to keep (not skip) 1 frame every skipFrame frames. + */ + public UniformSkippingFilter(int skipFrame) { + Preconditions.checkArgument(skipFrame > 0, "skipFrame should be greater than 0"); + + this.skipFrame = skipFrame; + } + + @Override + public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) { + return !isFinalObservation && (currentObservationStep % skipFrame != 0); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java new file mode 100644 index 000000000..a9214bbff --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.bytedeco.javacv.OpenCVFrameConverter; +import org.bytedeco.opencv.opencv_core.Mat; +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.bytedeco.opencv.global.opencv_core.CV_32FC; + +public class EncodableToINDArrayTransform implements Operation { + + private final int[] shape; + + public EncodableToINDArrayTransform(int[] shape) { + this.shape = shape; + } + + @Override + public INDArray transform(Encodable encodable) { + return Nd4j.create(encodable.toArray()).reshape(shape); + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java new file mode 100644 index 000000000..3eaeec4dc --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java @@ -0,0 +1,48 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.bytedeco.javacv.OpenCVFrameConverter; +import org.bytedeco.opencv.opencv_core.Mat; +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.bytedeco.opencv.global.opencv_core.CV_32FC; + +public class EncodableToImageWritableTransform implements Operation { + + private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); + private final int height; + private final int width; + private final int colorChannels; + + public EncodableToImageWritableTransform(int height, int width, int colorChannels) { + this.height = height; + this.width = width; + this.colorChannels = colorChannels; + } + + @Override + public ImageWritable transform(Encodable encodable) { + INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels); + Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer()); + return new ImageWritable(converter.convert(mat)); + } + +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java new file mode 100644 index 000000000..3a48c128a --- /dev/null +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java @@ -0,0 +1,52 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.legacy; + +import org.datavec.api.transform.Operation; +import org.datavec.image.data.ImageWritable; +import org.datavec.image.loader.NativeImageLoader; +import org.deeplearning4j.rl4j.space.Encodable; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.IOException; + +public class ImageWritableToINDArrayTransform implements Operation { + + private final int height; + private final int width; + private final NativeImageLoader loader; + + public ImageWritableToINDArrayTransform(int height, int width) { + this.height = height; + this.width = width; + this.loader = new NativeImageLoader(height, width); + } + + @Override + public INDArray transform(ImageWritable imageWritable) { + INDArray out = null; + try { + out = loader.asMatrix(imageWritable); + } catch (IOException e) { + e.printStackTrace(); + } + out = out.reshape(1, height, width); + INDArray compressed = out.castTo(DataType.UINT8); + return compressed; + } +} diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java new file mode 100644 index 000000000..e27d1134c --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java @@ -0,0 +1,147 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.datavec.api.transform.Operation; +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.deeplearning4j.rl4j.observation.transform.ResettableOperation; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryStackAssembler; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * The HistoryMergeTransform will accumulate features from incoming INDArrays and will assemble its content + * into a new INDArray containing a single example. + * + * This is used in scenarios where motion is an important element. + * + * There is a special case: + * * When the store is not full (not ready), the data from the incoming INDArray is stored but null is returned (will be interpreted as a skipped observation) + *
+ * The HistoryMergeTransform requires two sub components:
+ * 1) The {@link HistoryMergeElementStore HistoryMergeElementStore} that supervises what and how input INDArrays are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...) + * The default is a Circular FIFO. + * 2) The {@link HistoryMergeAssembler HistoryMergeAssembler} that will assemble the store content into a resulting single INDArray. (ex.: stacked along a dimension, squashed into a single observation, etc...) + * The default is stacking along the dimension 0. + * + * @author Alexandre Boulanger + */ +public class HistoryMergeTransform implements Operation, ResettableOperation { + + private final HistoryMergeElementStore historyMergeElementStore; + private final HistoryMergeAssembler historyMergeAssembler; + private final boolean shouldStoreCopy; + private final boolean isFirstDimenstionBatch; + + private HistoryMergeTransform(Builder builder) { + this.historyMergeElementStore = builder.historyMergeElementStore; + this.historyMergeAssembler = builder.historyMergeAssembler; + this.shouldStoreCopy = builder.shouldStoreCopy; + this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch; + } + + @Override + public INDArray transform(INDArray input) { + INDArray element; + if(isFirstDimenstionBatch) { + element = input.slice(0, 0); + } + else { + element = input; + } + + if(shouldStoreCopy) { + element = element.dup(); + } + + historyMergeElementStore.add(element); + if(!historyMergeElementStore.isReady()) { + return null; + } + + INDArray result = historyMergeAssembler.assemble(historyMergeElementStore.get()); + + return INDArrayHelper.forceCorrectShape(result); + } + + @Override + public void reset() { + historyMergeElementStore.reset(); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + private HistoryMergeElementStore historyMergeElementStore; + private HistoryMergeAssembler historyMergeAssembler; + private boolean shouldStoreCopy = false; + private boolean isFirstDimenstionBatch = false; + + /** + * 
Default is {@link CircularFifoStore CircularFifoStore} + */ + public Builder elementStore(HistoryMergeElementStore store) { + this.historyMergeElementStore = store; + return this; + } + + /** + * Default is {@link HistoryStackAssembler HistoryStackAssembler} + */ + public Builder assembler(HistoryMergeAssembler assembler) { + this.historyMergeAssembler = assembler; + return this; + } + + /** + * If true, tell the HistoryMergeTransform to store copies of incoming INDArrays. + * (To prevent later in-place changes to a stored INDArray from changing what has been stored) + * + * Default is false + */ + public Builder shouldStoreCopy(boolean shouldStoreCopy) { + this.shouldStoreCopy = shouldStoreCopy; + return this; + } + + /** + * If true, tell the HistoryMergeTransform that the first dimension of the input INDArray is the batch count. + * When this is the case, the HistoryMergeTransform will slice the input like this [batch, height, width] -> [height, width] + * + * Default is false + */ + public Builder isFirstDimenstionBatch(boolean isFirstDimenstionBatch) { + this.isFirstDimenstionBatch = isFirstDimenstionBatch; + return this; + } + + public HistoryMergeTransform build() { + if(historyMergeElementStore == null) { + historyMergeElementStore = new CircularFifoStore(); + } + + if(historyMergeAssembler == null) { + historyMergeAssembler = new HistoryStackAssembler(); + } + + return new HistoryMergeTransform(this); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java new file mode 100644 index 000000000..592a1f86d --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java @@ -0,0 +1,44 @@ +/******************************************************************************* + * Copyright 
(c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.datavec.api.transform.Operation; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class SimpleNormalizationTransform implements Operation { + + private final double offset; + private final double divisor; + + public SimpleNormalizationTransform(double min, double max) { + Preconditions.checkArgument(min < max, "Min must be smaller than max."); + + this.offset = min; + this.divisor = (max - min); + } + + @Override + public INDArray transform(INDArray input) { + if(offset != 0.0) { + input.subi(offset); + } + + input.divi(divisor); + + return input; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java similarity index 52% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java index 6eb950e48..db1cbb2bd 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPool.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java @@ -1,95 +1,82 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.apache.commons.collections4.queue.CircularFifoQueue; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue - * with a fixed size that replaces its oldest element if full. 
- * - * @author Alexandre Boulanger - */ -public class CircularFifoObservationPool implements ObservationPool { - private static final int DEFAULT_POOL_SIZE = 4; - - private final CircularFifoQueue queue; - - private CircularFifoObservationPool(Builder builder) { - queue = new CircularFifoQueue<>(builder.poolSize); - } - - public CircularFifoObservationPool() - { - this(DEFAULT_POOL_SIZE); - } - - public CircularFifoObservationPool(int poolSize) - { - Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize); - queue = new CircularFifoQueue<>(poolSize); - } - - /** - * Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one. - * @param elem - */ - public void add(INDArray elem) { - queue.add(elem); - } - - /** - * @return The content of the pool, returned in order from oldest to newest. - */ - public INDArray[] get() { - int size = queue.size(); - INDArray[] array = new INDArray[size]; - for (int i = 0; i < size; ++i) { - array[i] = queue.get(i).castTo(Nd4j.dataType()); - } - return array; - } - - public boolean isAtFullCapacity() { - return queue.isAtFullCapacity(); - } - - @Override - public void reset() { - queue.clear(); - } - - public static Builder builder() { - return new Builder(); - } - - public static class Builder { - private int poolSize = DEFAULT_POOL_SIZE; - - public Builder poolSize(int poolSize) { - this.poolSize = poolSize; - return this; - } - - public CircularFifoObservationPool build() { - return new CircularFifoObservationPool(this); - } - } -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.apache.commons.collections4.queue.CircularFifoQueue; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * CircularFifoStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This store is a first-in first-out queue + * with a fixed size that replaces its oldest element if full. + * + * @author Alexandre Boulanger + */ +public class CircularFifoStore implements HistoryMergeElementStore { + private static final int DEFAULT_STORE_SIZE = 4; + + private final CircularFifoQueue queue; + + public CircularFifoStore() { + this(DEFAULT_STORE_SIZE); + } + + public CircularFifoStore(int size) { + Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size); + queue = new CircularFifoQueue<>(size); + } + + /** + * Add an element to the store, if this addition would make the store to overflow, the new element replaces the oldest. + * @param elem + */ + @Override + public void add(INDArray elem) { + queue.add(elem); + } + + /** + * @return The content of the store, returned in order from oldest to newest. 
+ */ + @Override + public INDArray[] get() { + int size = queue.size(); + INDArray[] array = new INDArray[size]; + for (int i = 0; i < size; ++i) { + array[i] = queue.get(i).castTo(Nd4j.dataType()); + } + return array; + } + + /** + * The CircularFifoStore needs to be completely filled before being ready. + * @return false when the number of elements in the store is less than the store capacity (default is 4) + */ + @Override + public boolean isReady() { + return queue.isAtFullCapacity(); + } + + /** + * Clears the store. + */ + @Override + public void reset() { + queue.clear(); + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java new file mode 100644 index 000000000..0487d7c57 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java @@ -0,0 +1,35 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * A HistoryMergeAssembler is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This interface defines how the array of INDArray + * given by the {@link HistoryMergeElementStore HistoryMergeElementStore} is packaged into the single INDArray that will be + * returned by the HistoryMergeTransform + * + * @author Alexandre Boulanger + */ +public interface HistoryMergeAssembler { + /** + * Assemble an array of INDArray into a single INDArray + * @param elements The input INDArray[] + * @return the assembled INDArray + */ + INDArray assemble(INDArray[] elements); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java new file mode 100644 index 000000000..04d61da45 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java @@ -0,0 +1,51 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * HistoryMergeElementStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. Used to supervise how data from the + * HistoryMergeTransform is stored. + * + * @author Alexandre Boulanger + */ +public interface HistoryMergeElementStore { + /** + * Add an element into the store + * @param observation + */ + void add(INDArray observation); + + /** + * Get the content of the store + * @return the content of the store + */ + INDArray[] get(); + + /** + * Used to tell the HistoryMergeTransform that the store is ready. The HistoryMergeTransform will tell the {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess} + * to skip the observation if the store is not ready. + * @return true if the store is ready + */ + boolean isReady(); + + /** + * Resets the store to an initial state. 
+ */ + void reset(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java similarity index 62% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java index d53f5c8c8..0559f25df 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssembler.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java @@ -1,53 +1,52 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -/** - * ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will - * stack along the dimension 0. 
For example if the pool elements are of shape [ Height, Width ] - * the output will be of shape [ Stacked, Height, Width ] - * - * @author Alexandre Boulanger - */ -public class ChannelStackPoolContentAssembler implements PoolContentAssembler { - - /** - * Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0. - * - * @param poolContent Array of INDArray - * @return A new INDArray with 1 more dimension than the input elements - */ - @Override - public INDArray assemble(INDArray[] poolContent) - { - // build the new shape - long[] elementShape = poolContent[0].shape(); - long[] newShape = new long[elementShape.length + 1]; - newShape[0] = poolContent.length; - System.arraycopy(elementShape, 0, newShape, 1, elementShape.length); - - // put pool elements in result - INDArray result = Nd4j.create(newShape); - for(int i = 0; i < poolContent.length; ++i) { - result.putRow(i, poolContent[i]); - } - return result; - } -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +/** + * HistoryStackAssembler is used with the HistoryMergeTransform. 
This assembler will + * stack along the dimension 0. For example if the store elements are of shape [ Height, Width ] + * the output will be of shape [ Stacked, Height, Width ] + * + * @author Alexandre Boulanger + */ +public class HistoryStackAssembler implements HistoryMergeAssembler { + + /** + * Will return a new INDArray with one more dimension and with elements stacked along dimension 0. + * + * @param elements Array of INDArray + * @return A new INDArray with 1 more dimension than the input elements + */ + @Override + public INDArray assemble(INDArray[] elements) { + // build the new shape + long[] elementShape = elements[0].shape(); + long[] newShape = new long[elementShape.length + 1]; + newShape[0] = elements.length; + System.arraycopy(elementShape, 0, newShape, 1, elementShape.length); + + // stack the elements in result on the dimension 0 + INDArray result = Nd4j.create(newShape); + for(int i = 0; i < elements.length; ++i) { + result.putRow(i, elements[i]); + } + return result; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java index fb20a60ac..7719df612 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java @@ -89,7 +89,7 @@ public abstract class Policy implements IPolicy { getNeuralNet().reset(); } - private > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { epochStepCounter.setCurrentEpochStep(0); double reward = 0; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java index 26546a923..b0f46ef57 100644 --- 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java @@ -3,37 +3,95 @@ package org.deeplearning4j.rl4j.util; import lombok.AccessLevel; import lombok.Getter; import lombok.Setter; +import org.datavec.image.transform.ColorConversionTransform; +import org.datavec.image.transform.CropImageTransform; +import org.datavec.image.transform.MultiImageTransform; +import org.datavec.image.transform.ResizeImageTransform; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.learning.EpochStepCounter; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform; +import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import java.util.HashMap; +import java.util.Map; + +import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY; + public class LegacyMDPWrapper> implements MDP { @Getter private final MDP wrappedMDP; @Getter private final WrapperObservationSpace observationSpace; + private final int[] shape; - @Getter(AccessLevel.PRIVATE) 
@Setter(AccessLevel.PUBLIC) + @Setter + private TransformProcess transformProcess; + + @Getter(AccessLevel.PRIVATE) private IHistoryProcessor historyProcessor; private final EpochStepCounter epochStepCounter; private int skipFrame = 1; - private int requiredFrame = 0; public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) { this.wrappedMDP = wrappedMDP; - this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape()); + this.shape = wrappedMDP.getObservationSpace().getShape(); + this.observationSpace = new WrapperObservationSpace(shape); this.historyProcessor = historyProcessor; this.epochStepCounter = epochStepCounter; + + setHistoryProcessor(historyProcessor); + } + + public void setHistoryProcessor(IHistoryProcessor historyProcessor) { + this.historyProcessor = historyProcessor; + createTransformProcess(); + } + + private void createTransformProcess() { + IHistoryProcessor historyProcessor = getHistoryProcessor(); + + if(historyProcessor != null && shape.length == 3) { + int skipFrame = historyProcessor.getConf().getSkipFrame(); + + int finalHeight = historyProcessor.getConf().getCroppingHeight(); + int finalWidth = historyProcessor.getConf().getCroppingWidth(); + + transformProcess = TransformProcess.builder() + .filter(new UniformSkippingFilter(skipFrame)) + .transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2])) + .transform("data", new MultiImageTransform( + new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()), + new ColorConversionTransform(COLOR_BGR2GRAY), + new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth) + )) + .transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth)) + .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) + .transform("data", 
HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .build()) + .build("data"); + } + else { + transformProcess = TransformProcess.builder() + .transform("data", new EncodableToINDArrayTransform(shape)) + .build("data"); + } } @Override @@ -43,25 +101,17 @@ public class LegacyMDPWrapper> implements MDP channelsData = buildChannelsData(rawResetResponse); + return transformProcess.transform(channelsData, 0, false); } @Override @@ -71,32 +121,32 @@ public class LegacyMDPWrapper> implements MDP rawStepReply = wrappedMDP.step(a); INDArray rawObservation = getInput(rawStepReply.getObservation()); - int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; - if(historyProcessor != null) { historyProcessor.record(rawObservation); - - if (stepOfObservation % skipFrame == 0) { - historyProcessor.add(rawObservation); - } } - Observation observation; - if(historyProcessor != null && stepOfObservation >= requiredFrame) { - observation = new Observation(historyProcessor.getHistory(), true); - observation.getData().muli(1.0 / historyProcessor.getScale()); - } - else { - observation = new Observation(new INDArray[] { rawObservation }, false); - } - - if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) { - observation.setSkipped(true); - } + int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1; + Map channelsData = buildChannelsData(rawStepReply.getObservation()); + Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone()); return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo()); } + private void record(O obs) { + INDArray rawObservation = getInput(obs); + + IHistoryProcessor historyProcessor = getHistoryProcessor(); + if(historyProcessor != null) { + historyProcessor.record(rawObservation); + } + } + + private Map buildChannelsData(final O obs) { + return new HashMap() {{ + put("data", obs); + }}; + } + @Override public 
void close() { wrappedMDP.close(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java new file mode 100644 index 000000000..9bfceadad --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java @@ -0,0 +1,38 @@ +package org.deeplearning4j.rl4j.helper; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class INDArrayHelperTest { + @Test + public void when_inputHasIncorrectShape_expect_outputWithCorrectShape() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(3, output.shape()[1]); + } + + @Test + public void when_inputHasCorrectShape_expect_outputWithSameShape() { + // Arrange + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}).reshape(1, 3); + + // Act + INDArray output = INDArrayHelper.forceCorrectShape(input); + + // Assert + assertEquals(2, output.shape().length); + assertEquals(1, output.shape()[0]); + assertEquals(3, output.shape()[1]); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java index 8a0090b62..bc396502f 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java @@ -34,6 +34,7 @@ public class AsyncThreadDiscreteTest { MockPolicy policyMock = new MockPolicy(); MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 
0); TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); // Act sut.run(); @@ -60,12 +61,6 @@ public class AsyncThreadDiscreteTest { assertEquals(2, asyncGlobalMock.enqueueCallCount); // HistoryProcessor - double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 }; - assertEquals(expectedAddValues.length, hpMock.addCalls.size()); - for(int i = 0; i < expectedAddValues.length; ++i) { - assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001); - } - double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, }; assertEquals(expectedRecordValues.length, hpMock.recordCalls.size()); for(int i = 0; i < expectedRecordValues.length; ++i) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index b01105294..3dea25936 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -138,6 +138,7 @@ public class AsyncThreadTest { asyncGlobal.setMaxLoops(numEpochs); listeners.add(listener); sut.setHistoryProcessor(historyProcessor); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); } } @@ -209,7 +210,4 @@ public class AsyncThreadTest { int nstep; } } - - - } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java 
b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java index 73b27776a..08c8ba24e 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java @@ -18,7 +18,7 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition = buildTransition(buildObservation(), 123, 234, new Observation(Nd4j.create(1))); sut.store(transition); List> results = sut.getBatch(1); @@ -36,11 +36,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(2, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -78,11 +78,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { 
Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -100,11 +100,11 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -131,15 +131,15 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); - Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition4 = buildTransition(buildObservation(), 7, 8, new Observation(Nd4j.create(1))); - Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition5 = buildTransition(buildObservation(), 9, 10, new 
Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -168,15 +168,15 @@ public class ExpReplayTest { ExpReplay sut = new ExpReplay(5, 1, randomMock); // Act - Transition transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition1 = buildTransition(buildObservation(), 1, 2, new Observation(Nd4j.create(1))); - Transition transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition2 = buildTransition(buildObservation(), 3, 4, new Observation(Nd4j.create(1))); - Transition transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition3 = buildTransition(buildObservation(), 5, 6, new Observation(Nd4j.create(1))); - Transition transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition4 = buildTransition(buildObservation(), 7, 8, new Observation(Nd4j.create(1))); - Transition transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }), + Transition transition5 = buildTransition(buildObservation(), 9, 10, new Observation(Nd4j.create(1))); sut.store(transition1); sut.store(transition2); @@ -204,4 +204,8 @@ public class ExpReplayTest { return result; } + + private Observation buildObservation() { + return new Observation(Nd4j.create(1, 1)); + } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java index 944e41a31..374b6a140 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/TransitionTest.java @@ -193,11 +193,11 @@ public class TransitionTest { Nd4j.create(obs[1]).reshape(1, 3), Nd4j.create(obs[2]).reshape(1, 3), }; - return new Observation(history); + return new Observation(Nd4j.concat(0, 
history)); } private Observation buildObservation(double[] obs) { - return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) }); + return new Observation(Nd4j.create(obs).reshape(1, 3)); } private Observation buildNextObservation(double[][] obs, double[] nextObs) { @@ -206,7 +206,7 @@ public class TransitionTest { Nd4j.create(obs[0]).reshape(1, 3), Nd4j.create(obs[1]).reshape(1, 3), }; - return new Observation(nextHistory); + return new Observation(Nd4j.concat(0, nextHistory)); } private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index ee8b365f0..58aaab297 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -50,6 +50,7 @@ public class QLearningDiscreteTest { IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); sut.setHistoryProcessor(hp); + sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength())); List> results = new ArrayList<>(); // Act @@ -62,11 +63,7 @@ public class QLearningDiscreteTest { for(int i = 0; i < expectedRecords.length; ++i) { assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001); } - double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 }; - assertEquals(expectedAdds.length, hp.addCalls.size()); - for(int i = 0; i < expectedAdds.length; ++i) { - 
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001); - } + assertEquals(0, hp.startMonitorCallCount); assertEquals(0, hp.stopMonitorCallCount); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 760666f33..798bddf0d 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -106,7 +106,7 @@ public class DoubleDQNTest { } private Observation buildObservation(double[] data) { - return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + return new Observation(Nd4j.create(data).reshape(1, 2)); } private Transition builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index c540646a8..3e3701669 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -105,7 +105,7 @@ public class StandardDQNTest { } private Observation buildObservation(double[] data) { - return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)}); + return new Observation(Nd4j.create(data).reshape(1, 2)); } private Transition buildTransition(Observation observation, Integer action, double reward, 
boolean isTerminal, Observation nextObservation) { diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java deleted file mode 100644 index db239b0f3..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/PoolingDataSetPreProcessorTest.java +++ /dev/null @@ -1,164 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool; -import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler; -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static junit.framework.TestCase.assertTrue; -import static org.junit.Assert.assertEquals; - -public class PoolingDataSetPreProcessorTest { - - @Test(expected = NullPointerException.class) - public void when_dataSetIsNull_expect_NullPointerException() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - - // Act - sut.preProcess(null); - } - - @Test(expected = IllegalArgumentException.class) - public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - - // Act - sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null)); - } - - @Test - public void when_dataSetIsEmpty_expect_EmptyDataSet() { - // Assemble - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - DataSet ds = new DataSet(null, null); - - // Act - sut.preProcess(ds); - - // Assert - Assert.assertTrue(ds.isEmpty()); - } - - @Test - public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() { - 
// Arrange - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build(); - DataSet[] observations = new DataSet[5]; - INDArray[] inputs = new INDArray[5]; - - - // Act - for(int i = 0; i < 5; ++i) { - inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 }); - DataSet input = new DataSet(inputs[i], null); - sut.preProcess(input); - observations[i] = input; - } - - // Assert - assertTrue(observations[0].isEmpty()); - assertTrue(observations[1].isEmpty()); - assertTrue(observations[2].isEmpty()); - - for(int i = 0; i < 4; ++i) { - assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); - assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); - } - - for(int i = 0; i < 4; ++i) { - assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001); - assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001); - } - - } - - @Test - public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() { - // Arrange - INDArray input = Nd4j.rand(1, 1); - TestObservationPool pool = new TestObservationPool(); - TestPoolContentAssembler assembler = new TestPoolContentAssembler(); - PoolingDataSetPreProcessor sut = 
PoolingDataSetPreProcessor.builder() - .observablePool(pool) - .poolContentAssembler(assembler) - .build(); - - // Act - sut.preProcess(new DataSet(input, null)); - - // Assert - assertTrue(pool.isAtFullCapacityCalled); - assertTrue(pool.isGetCalled); - assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0); - assertTrue(assembler.assembleIsCalled); - } - - @Test - public void when_pastInputChanges_expect_outputNotChanged() { - // Arrange - INDArray input = Nd4j.zeros(1, 1); - TestObservationPool pool = new TestObservationPool(); - TestPoolContentAssembler assembler = new TestPoolContentAssembler(); - PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder() - .observablePool(pool) - .poolContentAssembler(assembler) - .build(); - - // Act - sut.preProcess(new DataSet(input, null)); - input.putScalar(0, 0, 1.0); - - // Assert - assertEquals(0.0, pool.observation.getDouble(0), 0.0); - } - - private static class TestObservationPool implements ObservationPool { - - public INDArray observation; - public boolean isGetCalled; - public boolean isAtFullCapacityCalled; - private boolean isResetCalled; - - @Override - public void add(INDArray observation) { - this.observation = observation; - } - - @Override - public INDArray[] get() { - isGetCalled = true; - return new INDArray[0]; - } - - @Override - public boolean isAtFullCapacity() { - isAtFullCapacityCalled = true; - return true; - } - - @Override - public void reset() { - isResetCalled = true; - } - } - - private static class TestPoolContentAssembler implements PoolContentAssembler { - - public boolean assembleIsCalled; - - @Override - public INDArray assemble(INDArray[] poolContent) { - assembleIsCalled = true; - return Nd4j.create(1, 1); - } - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java 
deleted file mode 100644 index 3f1de3426..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/SkippingDataSetPreProcessorTest.java +++ /dev/null @@ -1,70 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor; - -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class SkippingDataSetPreProcessorTest { - @Test(expected = IllegalArgumentException.class) - public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() { - SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0); - } - - @Test(expected = IllegalArgumentException.class) - public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() { - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(0) - .build(); - } - - @Test - public void when_skipFrameIs3_expect_Skip2OutOf3() { - // Arrange - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(3) - .build(); - DataSet[] results = new DataSet[4]; - - // Act - for(int i = 0; i < 4; ++i) { - results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - sut.preProcess(results[i]); - } - - // Assert - assertFalse(results[0].isEmpty()); - assertTrue(results[1].isEmpty()); - assertTrue(results[2].isEmpty()); - assertFalse(results[3].isEmpty()); - } - - @Test - public void when_resetIsCalled_expect_skippingIsReset() { - // Arrange - SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder() - .skipFrame(3) - .build(); - DataSet[] results = new DataSet[4]; - - // Act - results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null); - - 
sut.preProcess(results[0]); - sut.preProcess(results[1]); - sut.reset(); - sut.preProcess(results[2]); - sut.preProcess(results[3]); - - // Assert - assertFalse(results[0].isEmpty()); - assertTrue(results[1].isEmpty()); - assertFalse(results[2].isEmpty()); - assertTrue(results[3].isEmpty()); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java deleted file mode 100644 index de0db015c..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/ChannelStackPoolContentAssemblerTest.java +++ /dev/null @@ -1,41 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class ChannelStackPoolContentAssemblerTest { - - @Test - public void when_assemble_expect_poolContentStackedOnChannel() { - // Assemble - ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler(); - INDArray[] poolContent = new INDArray[] { - Nd4j.rand(2, 2), - Nd4j.rand(2, 2), - }; - - // Act - INDArray result = sut.assemble(poolContent); - - // Assert - assertEquals(3, result.shape().length); - assertEquals(2, result.shape()[0]); - assertEquals(2, result.shape()[1]); - assertEquals(2, result.shape()[2]); - - assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001); - assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001); - assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001); - assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001); - - assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001); - 
assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001); - assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001); - assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001); - - } - -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java deleted file mode 100644 index 88e7b33dd..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/preprocessor/pooling/CircularFifoObservationPoolTest.java +++ /dev/null @@ -1,100 +0,0 @@ -package org.deeplearning4j.rl4j.observation.preprocessor.pooling; - -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class CircularFifoObservationPoolTest { - - @Test(expected = IllegalArgumentException.class) - public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() { - CircularFifoObservationPool sut = new CircularFifoObservationPool(0); - } - - @Test - public void when_poolIsEmpty_expect_NotReady() { - // Assemble - CircularFifoObservationPool sut = new CircularFifoObservationPool(); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertFalse(isReady); - } - - @Test - public void when_notEnoughElementsInPool_expect_notReady() { - // Assemble - CircularFifoObservationPool sut = new CircularFifoObservationPool(); - sut.add(Nd4j.create(new double[] { 123.0 })); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertFalse(isReady); - } - - @Test - public void when_enoughElementsInPool_expect_ready() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder() 
- .poolSize(2) - .build(); - sut.add(Nd4j.createFromArray(123.0)); - sut.add(Nd4j.createFromArray(123.0)); - - // Act - boolean isReady = sut.isAtFullCapacity(); - - // Assert - assertTrue(isReady); - } - - @Test - public void when_addMoreThanSize_expect_getReturnOnlyLastElements() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); - sut.add(Nd4j.createFromArray(0.0)); - sut.add(Nd4j.createFromArray(1.0)); - sut.add(Nd4j.createFromArray(2.0)); - sut.add(Nd4j.createFromArray(3.0)); - sut.add(Nd4j.createFromArray(4.0)); - sut.add(Nd4j.createFromArray(5.0)); - sut.add(Nd4j.createFromArray(6.0)); - - // Act - INDArray[] result = sut.get(); - - // Assert - assertEquals(3.0, result[0].getDouble(0), 0.0); - assertEquals(4.0, result[1].getDouble(0), 0.0); - assertEquals(5.0, result[2].getDouble(0), 0.0); - assertEquals(6.0, result[3].getDouble(0), 0.0); - } - - @Test - public void when_resetIsCalled_expect_poolContentFlushed() { - // Assemble - CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build(); - sut.add(Nd4j.createFromArray(0.0)); - sut.add(Nd4j.createFromArray(1.0)); - sut.add(Nd4j.createFromArray(2.0)); - sut.add(Nd4j.createFromArray(3.0)); - sut.add(Nd4j.createFromArray(4.0)); - sut.add(Nd4j.createFromArray(5.0)); - sut.add(Nd4j.createFromArray(6.0)); - sut.reset(); - - // Act - INDArray[] result = sut.get(); - - // Assert - assertEquals(0, result.length); - } -} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java new file mode 100644 index 000000000..fe79bdfc7 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java @@ -0,0 +1,378 @@ +package org.deeplearning4j.rl4j.observation.transform; + +import org.deeplearning4j.rl4j.observation.Observation; +import 
org.junit.Test; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.DataSetPreProcessor; +import org.nd4j.linalg.factory.Nd4j; +import org.datavec.api.transform.Operation; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.*; + +public class TransformProcessTest { + @Test(expected = IllegalArgumentException.class) + public void when_noChannelNameIsSuppliedToBuild_expect_exception() { + // Arrange + TransformProcess.builder().build(); + } + + @Test(expected = IllegalArgumentException.class) + public void when_callingTransformWithNullArg_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + + // Act + sut.transform(null, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_callingTransformWithEmptyChannelData_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap(); + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = NullPointerException.class) + public void when_addingNullFilter_expect_nullException() { + // Act + TransformProcess.builder().filter(null); + } + + @Test + public void when_fileteredOut_expect_skippedObservationAndFollowingOperationsSkipped() { + // Arrange + IntegerTransformOperationMock transformOperationMock = new IntegerTransformOperationMock(); + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(true)) + .transform("test", transformOperationMock) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertTrue(result.isSkipped()); + assertFalse(transformOperationMock.isCalled); + } + + @Test(expected = NullPointerException.class) + public void when_addingTransformOnNullChannel_expect_nullException() { + // Act + 
TransformProcess.builder().transform(null, new IntegerTransformOperationMock()); + } + + @Test(expected = NullPointerException.class) + public void when_addingTransformWithNullTransform_expect_nullException() { + // Act + TransformProcess.builder().transform("test", null); + } + + @Test + public void when_transformIsCalled_expect_channelDataTransformedInSameOrder() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", new ToDataSetTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertFalse(result.isSkipped()); + assertEquals(-1.0, result.getData().getDouble(0), 0.00001); + } + + @Test(expected = NullPointerException.class) + public void when_addingPreProcessOnNullChannel_expect_nullException() { + // Act + TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock()); + } + + @Test(expected = NullPointerException.class) + public void when_addingPreProcessWithNullTransform_expect_nullException() { + // Act + TransformProcess.builder().transform("test", null); + } + + @Test + public void when_preProcessIsCalled_expect_channelDataPreProcessedInSameOrder() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", new ToDataSetTransformOperationMock()) + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + + // Assert + assertFalse(result.isSkipped()); + assertEquals(1, result.getData().shape().length); + assertEquals(1, result.getData().shape()[0]); + assertEquals(-10.0, result.getData().getDouble(0), 
0.00001); + } + + @Test(expected = IllegalStateException.class) + public void when_transformingNullData_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_transformingAndChannelsNotDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + + // Act + Observation result = sut.transform(null, 0, false); + } + + + @Test(expected = IllegalArgumentException.class) + public void when_transformingAndChannelsEmptyDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap(); + + // Act + Observation result = sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_buildIsCalledWithoutChannelNames_expect_exception() { + // Act + TransformProcess.builder().build(); + } + + @Test(expected = NullPointerException.class) + public void when_buildIsCalledWithNullChannelName_expect_exception() { + // Act + TransformProcess.builder().build(null); + } + + @Test + public void when_resetIsCalled_expect_resettableAreReset() { + // Arrange + ResettableTransformOperationMock resettableOperation = new ResettableTransformOperationMock(); + TransformProcess sut = TransformProcess.builder() + .filter(new FilterOperationMock(false)) + .transform("test", new IntegerTransformOperationMock()) + .transform("test", resettableOperation) + .build("test"); + + // Act + sut.reset(); + + // Assert + assertTrue(resettableOperation.isResetCalled); + } + + @Test + public void 
when_buildIsCalledAndAllChannelsAreDataSets_expect_observation() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new ToDataSetTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + + // Assert + assertFalse(result.isSkipped()); + + assertEquals(1.0, result.getData().getDouble(0), 0.00001); + } + + @Test + public void when_buildIsCalledAndAllChannelsAreINDArrays_expect_observation() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap() {{ + put("test", Nd4j.create(new double[] { 1.0 })); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + + // Assert + assertFalse(result.isSkipped()); + + assertEquals(1.0, result.getData().getDouble(0), 0.00001); + } + + @Test(expected = IllegalStateException.class) + public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + Observation result = sut.transform(channelsData, 123, true); + } + + @Test(expected = NullPointerException.class) + public void when_channelDataIsNull_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", null); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_transformAppliedOnChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + 
sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_preProcessAppliedOnChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("not-test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_buildContainsChannelNotInMap_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .transform("test", new IntegerTransformOperationMock()) + .build("not-test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + @Test(expected = IllegalArgumentException.class) + public void when_preProcessNotAppliedOnDataSet_expect_exception() { + // Arrange + TransformProcess sut = TransformProcess.builder() + .preProcess("test", new DataSetPreProcessorMock()) + .build("test"); + Map channelsData = new HashMap() {{ + put("test", 1); + }}; + + // Act + sut.transform(channelsData, 0, false); + } + + private static class FilterOperationMock implements FilterOperation { + + private final boolean skipped; + + public FilterOperationMock(boolean skipped) { + this.skipped = skipped; + } + + @Override + public boolean isSkipped(Map channelsData, int currentObservationStep, boolean isFinalObservation) { + return skipped; + } + } + + private static class IntegerTransformOperationMock implements Operation { + + public boolean isCalled = false; + + @Override + public Integer transform(Integer input) { + isCalled = true; + return -input; + } + } + + private static class ToDataSetTransformOperationMock implements Operation { + + @Override + public DataSet transform(Integer input) { + return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { input }), null); + } + } + + private static class 
ResettableTransformOperationMock implements Operation, ResettableOperation { + + private boolean isResetCalled = false; + + @Override + public Integer transform(Integer input) { + return input * 10; + } + + @Override + public void reset() { + isResetCalled = true; + } + } + + private static class DataSetPreProcessorMock implements DataSetPreProcessor { + + @Override + public void preProcess(DataSet dataSet) { + dataSet.getFeatures().muli(10.0); + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java new file mode 100644 index 000000000..3aa5a17cf --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java @@ -0,0 +1,54 @@ +package org.deeplearning4j.rl4j.observation.transform.filter; + +import org.deeplearning4j.rl4j.observation.transform.FilterOperation; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class UniformSkippingFilterTest { + + @Test(expected = IllegalArgumentException.class) + public void when_negativeSkipFrame_expect_exception() { + // Act + new UniformSkippingFilter(-1); + } + + @Test + public void when_skippingIs4_expect_firstNotSkippedOther3Skipped() { + // Assemble + FilterOperation sut = new UniformSkippingFilter(4); + boolean[] isSkipped = new boolean[8]; + + // Act + for(int i = 0; i < 8; ++i) { + isSkipped[i] = sut.isSkipped(null, i, false); + } + + // Assert + assertFalse(isSkipped[0]); + assertTrue(isSkipped[1]); + assertTrue(isSkipped[2]); + assertTrue(isSkipped[3]); + + assertFalse(isSkipped[4]); + assertTrue(isSkipped[5]); + assertTrue(isSkipped[6]); + assertTrue(isSkipped[7]); + } + + @Test + public void when_isLastObservation_expect_observationNotSkipped() { + // Assemble + FilterOperation sut = 
new UniformSkippingFilter(4); + + // Act + boolean isSkippedNotLastObservation = sut.isSkipped(null, 1, false); + boolean isSkippedLastObservation = sut.isSkipped(null, 1, true); + + // Assert + assertTrue(isSkippedNotLastObservation); + assertFalse(isSkippedLastObservation); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java new file mode 100644 index 000000000..9c7a172bb --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java @@ -0,0 +1,166 @@ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class HistoryMergeTransformTest { + + @Test + public void when_firstDimensionIsNotBatch_expect_observationAddedAsIs() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(false) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + sut.transform(input); + + // Assert + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_firstDimensionIsBatch_expect_observationAddedAsSliced() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .elementStore(store) + .build(); + INDArray input = 
Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3); + + // Act + sut.transform(input); + + // Assert + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_notReady_expect_resultIsNull() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .isFirstDimenstionBatch(true) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertNull(result); + } + + @Test + public void when_notShouldStoreCopy_expect_sameIsStored() { + // Arrange + MockStore store = new MockStore(false); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .shouldStoreCopy(false) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertSame(input, store.addedObservation); + } + + @Test + public void when_shouldStoreCopy_expect_copyIsStored() { + // Arrange + MockStore store = new MockStore(true); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .shouldStoreCopy(true) + .elementStore(store) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertNotSame(input, store.addedObservation); + assertEquals(1, store.addedObservation.shape().length); + assertEquals(3, store.addedObservation.shape()[0]); + } + + @Test + public void when_transformCalled_expect_storeContentAssembledAndOutputHasCorrectShape() { + // Arrange + MockStore store = new MockStore(true); + MockAssemble assemble = new MockAssemble(); + HistoryMergeTransform sut = HistoryMergeTransform.builder() + .elementStore(store) + .assembler(assemble) + .build(); + INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + + // Act + 
INDArray result = sut.transform(input); + + // Assert + assertEquals(1, assemble.assembleElements.length); + assertSame(store.addedObservation, assemble.assembleElements[0]); + + assertEquals(2, result.shape().length); + assertEquals(1, result.shape()[0]); + assertEquals(3, result.shape()[1]); + } + + public static class MockStore implements HistoryMergeElementStore { + + private final boolean isReady; + private INDArray addedObservation; + + public MockStore(boolean isReady) { + + this.isReady = isReady; + } + + @Override + public void add(INDArray observation) { + addedObservation = observation; + } + + @Override + public INDArray[] get() { + return new INDArray[] { addedObservation }; + } + + @Override + public boolean isReady() { + return isReady; + } + + @Override + public void reset() { + + } + } + + public static class MockAssemble implements HistoryMergeAssembler { + + private INDArray[] assembleElements; + + @Override + public INDArray assemble(INDArray[] elements) { + assembleElements = elements; + return elements[0]; + } + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java new file mode 100644 index 000000000..b330a4fb0 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java @@ -0,0 +1,31 @@ +package org.deeplearning4j.rl4j.observation.transform.operation; + +import org.deeplearning4j.rl4j.helper.INDArrayHelper; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class SimpleNormalizationTransformTest { + @Test(expected = IllegalArgumentException.class) + public void when_maxIsLessThanMin_expect_exception() { + // Arrange + SimpleNormalizationTransform sut 
= new SimpleNormalizationTransform(10.0, 1.0); + } + + @Test + public void when_transformIsCalled_expect_inputNormalized() { + // Arrange + SimpleNormalizationTransform sut = new SimpleNormalizationTransform(1.0, 11.0); + INDArray input = Nd4j.create(new double[] { 1.0, 11.0 }); + + // Act + INDArray result = sut.transform(input); + + // Assert + assertEquals(0.0, result.getDouble(0), 0.00001); + assertEquals(1.0, result.getDouble(1), 0.00001); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java new file mode 100644 index 000000000..f9b34a1f1 --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java @@ -0,0 +1,77 @@ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class CircularFifoStoreTest { + + @Test(expected = IllegalArgumentException.class) + public void when_fifoSizeIsLessThan1_expect_exception() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(0); + } + + @Test + public void when_adding2elementsWithSize2_expect_notReadyAfter1stReadyAfter2nd() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + boolean isReadyAfter1st = sut.isReady(); + sut.add(secondElement); + boolean isReadyAfter2nd = sut.isReady(); + + // Assert + assertFalse(isReadyAfter1st); + assertTrue(isReadyAfter2nd); + } + + @Test + public void when_adding2elementsWithSize2_expect_getReturnThese2() { + 
// Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + sut.add(secondElement); + INDArray[] results = sut.get(); + + // Assert + assertEquals(2, results.length); + + assertEquals(1.0, results[0].getDouble(0), 0.00001); + assertEquals(2.0, results[0].getDouble(1), 0.00001); + assertEquals(3.0, results[0].getDouble(2), 0.00001); + + assertEquals(10.0, results[1].getDouble(0), 0.00001); + assertEquals(20.0, results[1].getDouble(1), 0.00001); + assertEquals(30.0, results[1].getDouble(2), 0.00001); + + } + + @Test + public void when_adding2elementsThenCallingReset_expect_getReturnEmpty() { + // Arrange + CircularFifoStore sut = new CircularFifoStore(2); + INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }); + INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 }); + + // Act + sut.add(firstElement); + sut.add(secondElement); + sut.reset(); + INDArray[] results = sut.get(); + + // Assert + assertEquals(0, results.length); + } + +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java new file mode 100644 index 000000000..36826430e --- /dev/null +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java @@ -0,0 +1,37 @@ +package org.deeplearning4j.rl4j.observation.transform.operation.historymerge; + +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.*; + +public class HistoryStackAssemblerTest { + + @Test + public void when_assembling2INDArrays_expect_stackedAsResult() { + 
// Arrange + INDArray[] input = new INDArray[] { + Nd4j.create(new double[] { 1.0, 2.0, 3.0 }), + Nd4j.create(new double[] { 10.0, 20.0, 30.0 }), + }; + HistoryStackAssembler sut = new HistoryStackAssembler(); + + // Act + INDArray result = sut.assemble(input); + + // Assert + assertEquals(2, result.shape().length); + assertEquals(2, result.shape()[0]); + assertEquals(3, result.shape()[1]); + + assertEquals(1.0, result.getDouble(0, 0), 0.00001); + assertEquals(2.0, result.getDouble(0, 1), 0.00001); + assertEquals(3.0, result.getDouble(0, 2), 0.00001); + + assertEquals(10.0, result.getDouble(1, 0), 0.00001); + assertEquals(20.0, result.getDouble(1, 1), 0.00001); + assertEquals(30.0, result.getDouble(1, 2), 0.00001); + + } +} diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index 2262f1789..0707e16ab 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -23,15 +23,18 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.Learning; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.support.*; +import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import 
org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -186,8 +189,8 @@ public class PolicyTest { QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0, 0, 1.0, 0, 0, 0, 0, true); MockNeuralNet nnMock = new MockNeuralNet(); - MockRefacPolicy sut = new MockRefacPolicy(nnMock); IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2); + MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()); MockHistoryProcessor hp = new MockHistoryProcessor(hpConf); // Act @@ -197,13 +200,6 @@ public class PolicyTest { assertEquals(1, nnMock.resetCallCount); assertEquals(465.0, totalReward, 0.0001); - // HistoryProcessor - assertEquals(16, hp.addCalls.size()); - assertEquals(31, hp.recordCalls.size()); - for(int i=0; i <= 30; ++i) { - assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001); - } - // MDP assertEquals(1, mdp.resetCount); assertEquals(30, mdp.actions.size()); @@ -219,10 +215,15 @@ public class PolicyTest { public static class MockRefacPolicy extends Policy { private NeuralNet neuralNet; + private final int[] shape; + private final int skipFrame; + private final int historyLength; - public MockRefacPolicy(NeuralNet neuralNet) { - + public MockRefacPolicy(NeuralNet neuralNet, int[] shape, int skipFrame, int historyLength) { this.neuralNet = neuralNet; + this.shape = shape; + this.skipFrame = skipFrame; + this.historyLength = historyLength; } @Override @@ -239,5 +240,11 @@ public class PolicyTest { public Integer nextAction(INDArray input) { return (int)input.getDouble(0); } + + @Override + protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) { + mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength)); + return super.refacInitMdp(mdpWrapper, hp, 
epochStepCounter); + } } } diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java index c0ac23a2d..bbed87624 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java @@ -2,6 +2,12 @@ package org.deeplearning4j.rl4j.support; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter; +import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform; +import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.ObservationSpace; import org.nd4j.linalg.api.rng.Random; @@ -77,4 +83,16 @@ public class MockMDP implements MDP { public MDP newInstance() { return null; } + + public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) { + return TransformProcess.builder() + .filter(new UniformSkippingFilter(skipFrame)) + .transform("data", new EncodableToINDArrayTransform(shape)) + .transform("data", new SimpleNormalizationTransform(0.0, 255.0)) + .transform("data", HistoryMergeTransform.builder() + .elementStore(new CircularFifoStore(historyLength)) + .build()) + .build("data"); + } + } diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java index afeea5f7c..8a402a8ff 100644 --- 
a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java +++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java @@ -62,7 +62,7 @@ public class GymEnv> implements MDP { private static PyObject globals; static { try { - Py_SetPath(org.bytedeco.gym.presets.gym.cachePackages()); + Py_AddPath(org.bytedeco.gym.presets.gym.cachePackages()); program = Py_DecodeLocale(GymEnv.class.getSimpleName(), null); Py_SetProgramName(program); Py_Initialize();