RL4J: Add TransformProcess, part 2 (#8766)
* Part 2 of TransformProcess Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Fix compile errors Signed-off-by: Samuel Audet <samuel.audet@gmail.com> * Revert unrelated changes Signed-off-by: Samuel Audet <samuel.audet@gmail.com> Co-authored-by: Samuel Audet <samuel.audet@gmail.com>
This commit is contained in:
		
							parent
							
								
									0faf83b1b6
								
							
						
					
					
						commit
						8b10f0b876
					
				@ -1,5 +1,5 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
@ -17,57 +17,43 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation;
 | 
			
		||||
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.Setter;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSet;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Presently only a dummy container. Will contain observation channels when done.
 | 
			
		||||
 * Represent an observation from the environment
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class Observation {
 | 
			
		||||
    // TODO: Presently only a dummy container. Will contain observation channels when done.
 | 
			
		||||
 | 
			
		||||
    private final DataSet data;
 | 
			
		||||
    /**
 | 
			
		||||
     * A singleton representing a skipped observation
 | 
			
		||||
     */
 | 
			
		||||
    public static Observation SkippedObservation = new Observation(null);
 | 
			
		||||
 | 
			
		||||
    @Getter @Setter
 | 
			
		||||
    private boolean skipped;
 | 
			
		||||
    /**
 | 
			
		||||
     * @return A INDArray containing the data of the observation
 | 
			
		||||
     */
 | 
			
		||||
    @Getter
 | 
			
		||||
    private final INDArray data;
 | 
			
		||||
 | 
			
		||||
    public Observation(INDArray[] data) {
 | 
			
		||||
        this(data, false);
 | 
			
		||||
    public boolean isSkipped() {
 | 
			
		||||
        return data == null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public Observation(INDArray[] data, boolean shouldReshape) {
 | 
			
		||||
        INDArray features = Nd4j.concat(0, data);
 | 
			
		||||
        if(shouldReshape) {
 | 
			
		||||
            features = reshape(features);
 | 
			
		||||
        }
 | 
			
		||||
        this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // FIXME: Remove -- only used in unit tests
 | 
			
		||||
    public Observation(INDArray data) {
 | 
			
		||||
        this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private INDArray reshape(INDArray source) {
 | 
			
		||||
        long[] shape = source.shape();
 | 
			
		||||
        long[] nshape = new long[shape.length + 1];
 | 
			
		||||
        nshape[0] = 1;
 | 
			
		||||
        System.arraycopy(shape, 0, nshape, 1, shape.length);
 | 
			
		||||
 | 
			
		||||
        return source.reshape(nshape);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation(DataSet data) {
 | 
			
		||||
        this.data = data;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Creates a duplicate instance of the current observation
 | 
			
		||||
     * @return
 | 
			
		||||
     */
 | 
			
		||||
    public Observation dup() {
 | 
			
		||||
        return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null));
 | 
			
		||||
    }
 | 
			
		||||
        if(data == null) {
 | 
			
		||||
            return SkippedObservation;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    public INDArray getData() {
 | 
			
		||||
        return data.getFeatures();
 | 
			
		||||
        return new Observation(data.dup());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,231 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform;
 | 
			
		||||
 | 
			
		||||
import org.apache.commons.lang3.NotImplementedException;
 | 
			
		||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.Observation;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSet;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
 | 
			
		||||
import org.nd4j.shade.guava.collect.Maps;
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
 | 
			
		||||
import java.util.ArrayList;
 | 
			
		||||
import java.util.HashSet;
 | 
			
		||||
import java.util.List;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A TransformProcess will build an {@link Observation Observation} from the raw data coming from the environment.
 | 
			
		||||
 * Three types of steps are available:
 | 
			
		||||
 *   1) A {@link FilterOperation FilterOperation}: Used to determine if an observation should be skipped.
 | 
			
		||||
 *   2) An {@link Operation Operation}: Applies a transform and/or conversion to an observation channel.
 | 
			
		||||
 *   3) A {@link DataSetPreProcessor DataSetPreProcessor}: Applies a DataSetPreProcessor to the observation channel. The data in the channel must be a DataSet.
 | 
			
		||||
 *
 | 
			
		||||
 * Instances of the three types above can be called in any order. The only requirement is that when build() is called,
 | 
			
		||||
 * all channels must be instances of INDArrays or DataSets
 | 
			
		||||
 *
 | 
			
		||||
 *   NOTE: Presently, only single-channels observations are supported.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class TransformProcess {
 | 
			
		||||
 | 
			
		||||
    private final List<Map.Entry<String, Object>> operations;
 | 
			
		||||
    private final String[] channelNames;
 | 
			
		||||
    private final HashSet<String> operationsChannelNames;
 | 
			
		||||
 | 
			
		||||
    private TransformProcess(Builder builder, String... channelNames) {
 | 
			
		||||
        operations = builder.operations;
 | 
			
		||||
        this.channelNames = channelNames;
 | 
			
		||||
        this.operationsChannelNames = builder.requiredChannelNames;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * This method will call reset() of all steps implementing {@link ResettableOperation ResettableOperation} in the transform process.
 | 
			
		||||
     */
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        for(Map.Entry<String, Object> entry : operations) {
 | 
			
		||||
            if(entry.getValue() instanceof ResettableOperation) {
 | 
			
		||||
                ((ResettableOperation) entry.getValue()).reset();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Transforms the channel data into an Observation or a skipped observation depending on the specific steps in the transform process.
 | 
			
		||||
     *
 | 
			
		||||
     * @param channelsData A Map that maps the channel name to its data.
 | 
			
		||||
     * @param currentObservationStep The observation's step number within the episode.
 | 
			
		||||
     * @param isFinalObservation True if the observation is the last of the episode.
 | 
			
		||||
     * @return An observation (may be a skipped observation)
 | 
			
		||||
     */
 | 
			
		||||
    public Observation transform(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
 | 
			
		||||
        // null or empty channelData
 | 
			
		||||
        Preconditions.checkArgument(channelsData != null && channelsData.size() != 0, "Error: channelsData not supplied.");
 | 
			
		||||
 | 
			
		||||
        // Check that all channels have data
 | 
			
		||||
        for(Map.Entry<String, Object> channel : channelsData.entrySet()) {
 | 
			
		||||
            Preconditions.checkNotNull(channel.getValue(), "Error: data of channel '%s' is null", channel.getKey());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Check that all required channels are present
 | 
			
		||||
        for(String channelName : operationsChannelNames) {
 | 
			
		||||
            Preconditions.checkArgument(channelsData.containsKey(channelName), "The channelsData map does not contain the channel '%s'", channelName);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for(Map.Entry<String, Object> entry : operations) {
 | 
			
		||||
 | 
			
		||||
            // Filter
 | 
			
		||||
            if(entry.getValue() instanceof FilterOperation) {
 | 
			
		||||
                FilterOperation filterOperation = (FilterOperation)entry.getValue();
 | 
			
		||||
                if(filterOperation.isSkipped(channelsData, currentObservationStep, isFinalObservation)) {
 | 
			
		||||
                    return Observation.SkippedObservation;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Transform
 | 
			
		||||
            // null results are considered skipped observations
 | 
			
		||||
            else if(entry.getValue() instanceof Operation) {
 | 
			
		||||
                Operation transformOperation = (Operation)entry.getValue();
 | 
			
		||||
                Object transformed = transformOperation.transform(channelsData.get(entry.getKey()));
 | 
			
		||||
                if(transformed == null) {
 | 
			
		||||
                    return Observation.SkippedObservation;
 | 
			
		||||
                }
 | 
			
		||||
                channelsData.replace(entry.getKey(), transformed);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // PreProcess
 | 
			
		||||
            else if(entry.getValue() instanceof DataSetPreProcessor) {
 | 
			
		||||
                Object channelData = channelsData.get(entry.getKey());
 | 
			
		||||
                DataSetPreProcessor dataSetPreProcessor = (DataSetPreProcessor)entry.getValue();
 | 
			
		||||
                if(!(channelData instanceof DataSet)) {
 | 
			
		||||
                    throw new IllegalArgumentException("The channel data must be a DataSet to call preProcess");
 | 
			
		||||
                }
 | 
			
		||||
                dataSetPreProcessor.preProcess((DataSet)channelData);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            else {
 | 
			
		||||
                throw new IllegalArgumentException(String.format("Unknown operation: '%s'", entry.getValue().getClass().getName()));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Check that all channels used to build the observation are instances of
 | 
			
		||||
        // INDArray or DataSet
 | 
			
		||||
        // TODO: Add support for an interface with a toINDArray() method
 | 
			
		||||
        for(String channelName : channelNames) {
 | 
			
		||||
            Object channelData = channelsData.get(channelName);
 | 
			
		||||
 | 
			
		||||
            INDArray finalChannelData;
 | 
			
		||||
            if(channelData instanceof DataSet) {
 | 
			
		||||
                finalChannelData = ((DataSet)channelData).getFeatures();
 | 
			
		||||
            }
 | 
			
		||||
            else if(channelData instanceof INDArray) {
 | 
			
		||||
                finalChannelData = (INDArray) channelData;
 | 
			
		||||
            }
 | 
			
		||||
            else {
 | 
			
		||||
                throw new IllegalStateException("All channels used to build the observation must be instances of DataSet or INDArray");
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // The dimension 0 of all INDArrays must be 1 (batch count)
 | 
			
		||||
            channelsData.replace(channelName, INDArrayHelper.forceCorrectShape(finalChannelData));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // TODO: Add support to multi-channel observations
 | 
			
		||||
        INDArray data = ((INDArray) channelsData.get(channelNames[0]));
 | 
			
		||||
        return new Observation(data);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @return An instance of a builder
 | 
			
		||||
     */
 | 
			
		||||
    public static Builder builder() {
 | 
			
		||||
        return new Builder();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class Builder {
 | 
			
		||||
 | 
			
		||||
        private final List<Map.Entry<String, Object>> operations = new ArrayList<Map.Entry<String, Object>>();
 | 
			
		||||
        private final HashSet<String> requiredChannelNames = new HashSet<String>();
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Add a filter to the transform process steps. Used to skip observations on certain conditions.
 | 
			
		||||
         * See {@link FilterOperation FilterOperation}
 | 
			
		||||
         * @param filterOperation An instance
 | 
			
		||||
         */
 | 
			
		||||
        public Builder filter(FilterOperation filterOperation) {
 | 
			
		||||
            Preconditions.checkNotNull(filterOperation, "The filterOperation must not be null");
 | 
			
		||||
 | 
			
		||||
            operations.add((Map.Entry)Maps.immutableEntry(null, filterOperation));
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Add a transform to the steps. The transform can change the data and / or change the type of the data
 | 
			
		||||
         * (e.g. Byte[] to a ImageWritable)
 | 
			
		||||
         *
 | 
			
		||||
         * @param targetChannel The name of the channel to which the transform is applied.
 | 
			
		||||
         * @param transformOperation An instance of {@link Operation Operation}
 | 
			
		||||
         */
 | 
			
		||||
        public Builder transform(String targetChannel, Operation transformOperation) {
 | 
			
		||||
            Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
 | 
			
		||||
            Preconditions.checkNotNull(transformOperation, "The transformOperation must not be null");
 | 
			
		||||
 | 
			
		||||
            requiredChannelNames.add(targetChannel);
 | 
			
		||||
            operations.add((Map.Entry)Maps.immutableEntry(targetChannel, transformOperation));
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Add a DataSetPreProcessor to the steps. The channel must be a DataSet instance at this step.
 | 
			
		||||
         * @param targetChannel The name of the channel to which the pre processor is applied.
 | 
			
		||||
         * @param dataSetPreProcessor
 | 
			
		||||
         */
 | 
			
		||||
        public Builder preProcess(String targetChannel, DataSetPreProcessor dataSetPreProcessor) {
 | 
			
		||||
            Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
 | 
			
		||||
            Preconditions.checkNotNull(dataSetPreProcessor, "The dataSetPreProcessor must not be null");
 | 
			
		||||
 | 
			
		||||
            requiredChannelNames.add(targetChannel);
 | 
			
		||||
            operations.add((Map.Entry)Maps.immutableEntry(targetChannel, dataSetPreProcessor));
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Builds the TransformProcess.
 | 
			
		||||
         * @param channelNames A subset of channel names to be used to build the observation
 | 
			
		||||
         * @return An instance of TransformProcess
 | 
			
		||||
         */
 | 
			
		||||
        public TransformProcess build(String... channelNames) {
 | 
			
		||||
            if(channelNames.length == 0) {
 | 
			
		||||
                throw new IllegalArgumentException("At least one channel must be supplied.");
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            for(String channelName : channelNames) {
 | 
			
		||||
                Preconditions.checkNotNull(channelName, "Error: got a null channel name");
 | 
			
		||||
                requiredChannelNames.add(channelName);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // TODO: Remove when multi-channel observation is supported
 | 
			
		||||
            if(channelNames.length != 1) {
 | 
			
		||||
                throw new NotImplementedException("Multi-channel observations is not presently supported.");
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            return new TransformProcess(this, channelNames);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -25,14 +25,14 @@ import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
 | 
			
		||||
 | 
			
		||||
public class EncodableToImageWriteableTransform implements Operation<Encodable, ImageWritable> {
 | 
			
		||||
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
 | 
			
		||||
 | 
			
		||||
    private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
 | 
			
		||||
    private final int height;
 | 
			
		||||
    private final int width;
 | 
			
		||||
    private final int colorChannels;
 | 
			
		||||
 | 
			
		||||
    public EncodableToImageWriteableTransform(int height, int width, int colorChannels) {
 | 
			
		||||
    public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
 | 
			
		||||
        this.height = height;
 | 
			
		||||
        this.width = width;
 | 
			
		||||
        this.colorChannels = colorChannels;
 | 
			
		||||
@ -1,3 +1,18 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
 | 
			
		||||
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
@ -10,13 +25,13 @@ import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
 | 
			
		||||
public class ImageWriteableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
 | 
			
		||||
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
 | 
			
		||||
 | 
			
		||||
    private final int height;
 | 
			
		||||
    private final int width;
 | 
			
		||||
    private final NativeImageLoader loader;
 | 
			
		||||
 | 
			
		||||
    public ImageWriteableToINDArrayTransform(int height, int width) {
 | 
			
		||||
    public ImageWritableToINDArrayTransform(int height, int width) {
 | 
			
		||||
        this.height = height;
 | 
			
		||||
        this.width = width;
 | 
			
		||||
        this.loader = new NativeImageLoader(height, width);
 | 
			
		||||
@ -0,0 +1,44 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation;
 | 
			
		||||
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
public class SimpleNormalizationTransform implements Operation<INDArray, INDArray> {
 | 
			
		||||
 | 
			
		||||
    private final double offset;
 | 
			
		||||
    private final double divisor;
 | 
			
		||||
 | 
			
		||||
    public SimpleNormalizationTransform(double min, double max) {
 | 
			
		||||
        Preconditions.checkArgument(min < max, "Min must be smaller than max.");
 | 
			
		||||
 | 
			
		||||
        this.offset = min;
 | 
			
		||||
        this.divisor = (max - min);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray transform(INDArray input) {
 | 
			
		||||
        if(offset != 0.0) {
 | 
			
		||||
            input.subi(offset);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        input.divi(divisor);
 | 
			
		||||
 | 
			
		||||
        return input;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -3,37 +3,95 @@ package org.deeplearning4j.rl4j.util;
 | 
			
		||||
import lombok.AccessLevel;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.Setter;
 | 
			
		||||
import org.datavec.image.transform.ColorConversionTransform;
 | 
			
		||||
import org.datavec.image.transform.CropImageTransform;
 | 
			
		||||
import org.datavec.image.transform.MultiImageTransform;
 | 
			
		||||
import org.datavec.image.transform.ResizeImageTransform;
 | 
			
		||||
import org.deeplearning4j.gym.StepReply;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 | 
			
		||||
import org.deeplearning4j.rl4j.mdp.MDP;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.Observation;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.ActionSpace;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.Encodable;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;
 | 
			
		||||
 | 
			
		||||
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
 | 
			
		||||
 | 
			
		||||
    @Getter
 | 
			
		||||
    private final MDP<O, A, AS> wrappedMDP;
 | 
			
		||||
    @Getter
 | 
			
		||||
    private final WrapperObservationSpace observationSpace;
 | 
			
		||||
    private final int[] shape;
 | 
			
		||||
 | 
			
		||||
    @Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
 | 
			
		||||
    @Setter
 | 
			
		||||
    private TransformProcess transformProcess;
 | 
			
		||||
 | 
			
		||||
    @Getter(AccessLevel.PRIVATE)
 | 
			
		||||
    private IHistoryProcessor historyProcessor;
 | 
			
		||||
 | 
			
		||||
    private final EpochStepCounter epochStepCounter;
 | 
			
		||||
 | 
			
		||||
    private int skipFrame = 1;
 | 
			
		||||
    private int requiredFrame = 0;
 | 
			
		||||
 | 
			
		||||
    public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
 | 
			
		||||
        this.wrappedMDP = wrappedMDP;
 | 
			
		||||
        this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
 | 
			
		||||
        this.shape = wrappedMDP.getObservationSpace().getShape();
 | 
			
		||||
        this.observationSpace = new WrapperObservationSpace(shape);
 | 
			
		||||
        this.historyProcessor = historyProcessor;
 | 
			
		||||
        this.epochStepCounter = epochStepCounter;
 | 
			
		||||
 | 
			
		||||
        setHistoryProcessor(historyProcessor);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
 | 
			
		||||
        this.historyProcessor = historyProcessor;
 | 
			
		||||
        createTransformProcess();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void createTransformProcess() {
 | 
			
		||||
        IHistoryProcessor historyProcessor = getHistoryProcessor();
 | 
			
		||||
 | 
			
		||||
        if(historyProcessor != null && shape.length == 3) {
 | 
			
		||||
            int skipFrame = historyProcessor.getConf().getSkipFrame();
 | 
			
		||||
 | 
			
		||||
            int finalHeight = historyProcessor.getConf().getCroppingHeight();
 | 
			
		||||
            int finalWidth = historyProcessor.getConf().getCroppingWidth();
 | 
			
		||||
 | 
			
		||||
            transformProcess = TransformProcess.builder()
 | 
			
		||||
                    .filter(new UniformSkippingFilter(skipFrame))
 | 
			
		||||
                    .transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2]))
 | 
			
		||||
                    .transform("data", new MultiImageTransform(
 | 
			
		||||
                            new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
 | 
			
		||||
                            new ColorConversionTransform(COLOR_BGR2GRAY),
 | 
			
		||||
                            new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth)
 | 
			
		||||
                    ))
 | 
			
		||||
                    .transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth))
 | 
			
		||||
                    .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
 | 
			
		||||
                    .transform("data", HistoryMergeTransform.builder()
 | 
			
		||||
                            .isFirstDimenstionBatch(true)
 | 
			
		||||
                            .build())
 | 
			
		||||
                    .build("data");
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
            transformProcess = TransformProcess.builder()
 | 
			
		||||
                    .transform("data", new EncodableToINDArrayTransform(shape))
 | 
			
		||||
                    .build("data");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -43,25 +101,17 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public Observation reset() {
 | 
			
		||||
        INDArray rawObservation = getInput(wrappedMDP.reset());
 | 
			
		||||
        transformProcess.reset();
 | 
			
		||||
 | 
			
		||||
        IHistoryProcessor historyProcessor = getHistoryProcessor();
 | 
			
		||||
        if(historyProcessor != null) {
 | 
			
		||||
            historyProcessor.record(rawObservation);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Observation observation = new Observation(new INDArray[] { rawObservation }, false);
 | 
			
		||||
        O rawResetResponse = wrappedMDP.reset();
 | 
			
		||||
        record(rawResetResponse);
 | 
			
		||||
 | 
			
		||||
        if(historyProcessor != null) {
 | 
			
		||||
            skipFrame = historyProcessor.getConf().getSkipFrame();
 | 
			
		||||
            requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
 | 
			
		||||
 | 
			
		||||
            historyProcessor.add(rawObservation);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        observation.setSkipped(skipFrame != 0);
 | 
			
		||||
 | 
			
		||||
        return observation;
 | 
			
		||||
        Map<String, Object> channelsData = buildChannelsData(rawResetResponse);
 | 
			
		||||
        return transformProcess.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
@ -71,32 +121,32 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
 | 
			
		||||
        StepReply<O> rawStepReply = wrappedMDP.step(a);
 | 
			
		||||
        INDArray rawObservation = getInput(rawStepReply.getObservation());
 | 
			
		||||
 | 
			
		||||
        int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
 | 
			
		||||
 | 
			
		||||
        if(historyProcessor != null) {
 | 
			
		||||
            historyProcessor.record(rawObservation);
 | 
			
		||||
 | 
			
		||||
            if (stepOfObservation % skipFrame == 0) {
 | 
			
		||||
                historyProcessor.add(rawObservation);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Observation observation;
 | 
			
		||||
        if(historyProcessor != null && stepOfObservation >= requiredFrame) {
 | 
			
		||||
            observation = new Observation(historyProcessor.getHistory(), true);
 | 
			
		||||
            observation.getData().muli(1.0 / historyProcessor.getScale());
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
            observation = new Observation(new INDArray[] { rawObservation }, false);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
 | 
			
		||||
            observation.setSkipped(true);
 | 
			
		||||
        }
 | 
			
		||||
        int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
 | 
			
		||||
 | 
			
		||||
        Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
 | 
			
		||||
        Observation observation =  transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
 | 
			
		||||
        return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private void record(O obs) {
 | 
			
		||||
        INDArray rawObservation = getInput(obs);
 | 
			
		||||
 | 
			
		||||
        IHistoryProcessor historyProcessor = getHistoryProcessor();
 | 
			
		||||
        if(historyProcessor != null) {
 | 
			
		||||
            historyProcessor.record(rawObservation);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Map<String, Object> buildChannelsData(final O obs) {
 | 
			
		||||
        return new HashMap<String, Object>() {{
 | 
			
		||||
            put("data", obs);
 | 
			
		||||
        }};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void close() {
 | 
			
		||||
        wrappedMDP.close();
 | 
			
		||||
 | 
			
		||||
@ -34,6 +34,7 @@ public class AsyncThreadDiscreteTest {
 | 
			
		||||
        MockPolicy policyMock = new MockPolicy();
 | 
			
		||||
        MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
 | 
			
		||||
        TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
 | 
			
		||||
        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.run();
 | 
			
		||||
@ -60,12 +61,6 @@ public class AsyncThreadDiscreteTest {
 | 
			
		||||
        assertEquals(2, asyncGlobalMock.enqueueCallCount);
 | 
			
		||||
 | 
			
		||||
        // HistoryProcessor
 | 
			
		||||
        double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
 | 
			
		||||
        assertEquals(expectedAddValues.length, hpMock.addCalls.size());
 | 
			
		||||
        for(int i = 0; i < expectedAddValues.length; ++i) {
 | 
			
		||||
            assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
 | 
			
		||||
        assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
 | 
			
		||||
        for(int i = 0; i < expectedRecordValues.length; ++i) {
 | 
			
		||||
 | 
			
		||||
@ -138,6 +138,7 @@ public class AsyncThreadTest {
 | 
			
		||||
            asyncGlobal.setMaxLoops(numEpochs);
 | 
			
		||||
            listeners.add(listener);
 | 
			
		||||
            sut.setHistoryProcessor(historyProcessor);
 | 
			
		||||
            sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -209,7 +210,4 @@ public class AsyncThreadTest {
 | 
			
		||||
            int nstep;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition = buildTransition(buildObservation(),
 | 
			
		||||
                123, 234, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition);
 | 
			
		||||
        List<Transition<Integer>> results = sut.getBatch(1);
 | 
			
		||||
@ -36,11 +36,11 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(buildObservation(),
 | 
			
		||||
                1, 2, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(buildObservation(),
 | 
			
		||||
                3, 4, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(buildObservation(),
 | 
			
		||||
                5, 6, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition1);
 | 
			
		||||
        sut.store(transition2);
 | 
			
		||||
@ -78,11 +78,11 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(buildObservation(),
 | 
			
		||||
                1, 2, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(buildObservation(),
 | 
			
		||||
                3, 4, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(buildObservation(),
 | 
			
		||||
                5, 6, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition1);
 | 
			
		||||
        sut.store(transition2);
 | 
			
		||||
@ -100,11 +100,11 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(buildObservation(),
 | 
			
		||||
                1, 2, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(buildObservation(),
 | 
			
		||||
                3, 4, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(buildObservation(),
 | 
			
		||||
                5, 6, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition1);
 | 
			
		||||
        sut.store(transition2);
 | 
			
		||||
@ -131,15 +131,15 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(buildObservation(),
 | 
			
		||||
                1, 2, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(buildObservation(),
 | 
			
		||||
                3, 4, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(buildObservation(),
 | 
			
		||||
                5, 6, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition4 = buildTransition(buildObservation(),
 | 
			
		||||
                7, 8, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition5 = buildTransition(buildObservation(),
 | 
			
		||||
                9, 10, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition1);
 | 
			
		||||
        sut.store(transition2);
 | 
			
		||||
@ -168,15 +168,15 @@ public class ExpReplayTest {
 | 
			
		||||
        ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition1 = buildTransition(buildObservation(),
 | 
			
		||||
                1, 2, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition2 = buildTransition(buildObservation(),
 | 
			
		||||
                3, 4, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition3 = buildTransition(buildObservation(),
 | 
			
		||||
                5, 6, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition4 = buildTransition(buildObservation(),
 | 
			
		||||
                7, 8, new Observation(Nd4j.create(1)));
 | 
			
		||||
        Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
 | 
			
		||||
        Transition<Integer> transition5 = buildTransition(buildObservation(),
 | 
			
		||||
                9, 10, new Observation(Nd4j.create(1)));
 | 
			
		||||
        sut.store(transition1);
 | 
			
		||||
        sut.store(transition2);
 | 
			
		||||
@ -204,4 +204,8 @@ public class ExpReplayTest {
 | 
			
		||||
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation buildObservation() {
 | 
			
		||||
        return new Observation(Nd4j.create(1, 1));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -193,11 +193,11 @@ public class TransitionTest {
 | 
			
		||||
                Nd4j.create(obs[1]).reshape(1, 3),
 | 
			
		||||
                Nd4j.create(obs[2]).reshape(1, 3),
 | 
			
		||||
        };
 | 
			
		||||
        return new Observation(history);
 | 
			
		||||
        return new Observation(Nd4j.concat(0, history));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation buildObservation(double[] obs) {
 | 
			
		||||
        return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) });
 | 
			
		||||
        return new Observation(Nd4j.create(obs).reshape(1, 3));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation buildNextObservation(double[][] obs, double[] nextObs) {
 | 
			
		||||
@ -206,7 +206,7 @@ public class TransitionTest {
 | 
			
		||||
                Nd4j.create(obs[0]).reshape(1, 3),
 | 
			
		||||
                Nd4j.create(obs[1]).reshape(1, 3),
 | 
			
		||||
        };
 | 
			
		||||
        return new Observation(nextHistory);
 | 
			
		||||
        return new Observation(Nd4j.concat(0, nextHistory));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
 | 
			
		||||
 | 
			
		||||
@ -50,6 +50,7 @@ public class QLearningDiscreteTest {
 | 
			
		||||
        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
 | 
			
		||||
        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
 | 
			
		||||
        sut.setHistoryProcessor(hp);
 | 
			
		||||
        sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
 | 
			
		||||
        List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
@ -62,11 +63,7 @@ public class QLearningDiscreteTest {
 | 
			
		||||
        for(int i = 0; i < expectedRecords.length; ++i) {
 | 
			
		||||
            assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
 | 
			
		||||
        }
 | 
			
		||||
        double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
 | 
			
		||||
        assertEquals(expectedAdds.length, hp.addCalls.size());
 | 
			
		||||
        for(int i = 0; i < expectedAdds.length; ++i) {
 | 
			
		||||
            assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        assertEquals(0, hp.startMonitorCallCount);
 | 
			
		||||
        assertEquals(0, hp.stopMonitorCallCount);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -106,7 +106,7 @@ public class DoubleDQNTest {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation buildObservation(double[] data) {
 | 
			
		||||
        return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
 | 
			
		||||
        return new Observation(Nd4j.create(data).reshape(1, 2));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
 | 
			
		||||
 | 
			
		||||
@ -105,7 +105,7 @@ public class StandardDQNTest {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Observation buildObservation(double[] data) {
 | 
			
		||||
        return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
 | 
			
		||||
        return new Observation(Nd4j.create(data).reshape(1, 2));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,378 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.Observation;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSet;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
 | 
			
		||||
import java.util.HashMap;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class TransformProcessTest {
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_noChannelNameIsSuppliedToBuild_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess.builder().build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_callingTransformWithNullArg_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .build("test");
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(null, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_callingTransformWithEmptyChannelData_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_addingNullFilter_expect_nullException() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().filter(null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_fileteredOut_expect_skippedObservationAndFollowingOperationsSkipped() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        IntegerTransformOperationMock transformOperationMock = new IntegerTransformOperationMock();
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .filter(new FilterOperationMock(true))
 | 
			
		||||
                .transform("test", transformOperationMock)
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 0, false);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(result.isSkipped());
 | 
			
		||||
        assertFalse(transformOperationMock.isCalled);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_addingTransformOnNullChannel_expect_nullException() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().transform(null, new IntegerTransformOperationMock());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_addingTransformWithNullTransform_expect_nullException() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().transform("test", null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_transformIsCalled_expect_channelDataTransformedInSameOrder() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .filter(new FilterOperationMock(false))
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .transform("test", new ToDataSetTransformOperationMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 0, false);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(result.isSkipped());
 | 
			
		||||
        assertEquals(-1.0, result.getData().getDouble(0), 0.00001);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_addingPreProcessOnNullChannel_expect_nullException() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_addingPreProcessWithNullTransform_expect_nullException() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().transform("test", null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_preProcessIsCalled_expect_channelDataPreProcessedInSameOrder() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .filter(new FilterOperationMock(false))
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .transform("test", new ToDataSetTransformOperationMock())
 | 
			
		||||
                .preProcess("test", new DataSetPreProcessorMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 0, false);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(result.isSkipped());
 | 
			
		||||
        assertEquals(1, result.getData().shape().length);
 | 
			
		||||
        assertEquals(1, result.getData().shape()[0]);
 | 
			
		||||
        assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalStateException.class)
 | 
			
		||||
    public void when_transformingNullData_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_transformingAndChannelsNotDataSet_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .preProcess("test", new DataSetPreProcessorMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(null, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_transformingAndChannelsEmptyDataSet_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .preProcess("test", new DataSetPreProcessorMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_buildIsCalledWithoutChannelNames_expect_exception() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_buildIsCalledWithNullChannelName_expect_exception() {
 | 
			
		||||
        // Act
 | 
			
		||||
        TransformProcess.builder().build(null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_resetIsCalled_expect_resettableAreReset() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        ResettableTransformOperationMock resettableOperation = new ResettableTransformOperationMock();
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .filter(new FilterOperationMock(false))
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .transform("test", resettableOperation)
 | 
			
		||||
                .build("test");
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.reset();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(resettableOperation.isResetCalled);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_buildIsCalledAndAllChannelsAreDataSets_expect_observation() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .transform("test", new ToDataSetTransformOperationMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 123, true);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(result.isSkipped());
 | 
			
		||||
 | 
			
		||||
        assertEquals(1.0, result.getData().getDouble(0), 0.00001);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_buildIsCalledAndAllChannelsAreINDArrays_expect_observation() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", Nd4j.create(new double[] { 1.0 }));
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 123, true);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(result.isSkipped());
 | 
			
		||||
 | 
			
		||||
        assertEquals(1.0, result.getData().getDouble(0), 0.00001);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalStateException.class)
 | 
			
		||||
    public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        Observation result = sut.transform(channelsData, 123, true);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_channelDataIsNull_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", null);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_transformAppliedOnChannelNotInMap_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("not-test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_preProcessAppliedOnChannelNotInMap_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .preProcess("test", new DataSetPreProcessorMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("not-test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_buildContainsChannelNotInMap_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .transform("test", new IntegerTransformOperationMock())
 | 
			
		||||
                .build("not-test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_preProcessNotAppliedOnDataSet_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        TransformProcess sut = TransformProcess.builder()
 | 
			
		||||
                .preProcess("test", new DataSetPreProcessorMock())
 | 
			
		||||
                .build("test");
 | 
			
		||||
        Map<String, Object> channelsData = new HashMap<String, Object>() {{
 | 
			
		||||
            put("test", 1);
 | 
			
		||||
        }};
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(channelsData, 0, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class FilterOperationMock implements FilterOperation {
 | 
			
		||||
 | 
			
		||||
        private final boolean skipped;
 | 
			
		||||
 | 
			
		||||
        public FilterOperationMock(boolean skipped) {
 | 
			
		||||
            this.skipped = skipped;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
 | 
			
		||||
            return skipped;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class IntegerTransformOperationMock implements Operation<Integer, Integer> {
 | 
			
		||||
 | 
			
		||||
        public boolean isCalled = false;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public Integer transform(Integer input) {
 | 
			
		||||
            isCalled = true;
 | 
			
		||||
            return -input;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class ToDataSetTransformOperationMock implements Operation<Integer, DataSet> {
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public DataSet transform(Integer input) {
 | 
			
		||||
            return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { input }), null);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class ResettableTransformOperationMock implements Operation<Integer, Integer>, ResettableOperation {
 | 
			
		||||
 | 
			
		||||
        private boolean isResetCalled = false;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public Integer transform(Integer input) {
 | 
			
		||||
            return input * 10;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void reset() {
 | 
			
		||||
            isResetCalled = true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class DataSetPreProcessorMock implements DataSetPreProcessor {
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void preProcess(DataSet dataSet) {
 | 
			
		||||
            dataSet.getFeatures().muli(10.0);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,54 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.filter;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.FilterOperation;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.assertFalse;
 | 
			
		||||
import static org.junit.Assert.assertTrue;
 | 
			
		||||
 | 
			
		||||
public class UniformSkippingFilterTest {
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_negativeSkipFrame_expect_exception() {
 | 
			
		||||
        // Act
 | 
			
		||||
        new UniformSkippingFilter(-1);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_skippingIs4_expect_firstNotSkippedOther3Skipped() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        FilterOperation sut = new UniformSkippingFilter(4);
 | 
			
		||||
        boolean[] isSkipped = new boolean[8];
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        for(int i = 0; i < 8; ++i) {
 | 
			
		||||
            isSkipped[i] = sut.isSkipped(null, i, false);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(isSkipped[0]);
 | 
			
		||||
        assertTrue(isSkipped[1]);
 | 
			
		||||
        assertTrue(isSkipped[2]);
 | 
			
		||||
        assertTrue(isSkipped[3]);
 | 
			
		||||
 | 
			
		||||
        assertFalse(isSkipped[4]);
 | 
			
		||||
        assertTrue(isSkipped[5]);
 | 
			
		||||
        assertTrue(isSkipped[6]);
 | 
			
		||||
        assertTrue(isSkipped[7]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_isLastObservation_expect_observationNotSkipped() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        FilterOperation sut = new UniformSkippingFilter(4);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        boolean isSkippedNotLastObservation = sut.isSkipped(null, 1, false);
 | 
			
		||||
        boolean isSkippedLastObservation = sut.isSkipped(null, 1, true);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(isSkippedNotLastObservation);
 | 
			
		||||
        assertFalse(isSkippedLastObservation);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,31 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class SimpleNormalizationTransformTest {
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_maxIsLessThanMin_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_transformIsCalled_expect_inputNormalized() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        SimpleNormalizationTransform sut = new SimpleNormalizationTransform(1.0, 11.0);
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 11.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(0.0, result.getDouble(0), 0.00001);
 | 
			
		||||
        assertEquals(1.0, result.getDouble(1), 0.00001);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -23,15 +23,18 @@ import org.deeplearning4j.nn.gradient.Gradient;
 | 
			
		||||
import org.deeplearning4j.nn.graph.ComputationGraph;
 | 
			
		||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.Learning;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
 | 
			
		||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
 | 
			
		||||
import org.deeplearning4j.rl4j.mdp.MDP;
 | 
			
		||||
import org.deeplearning4j.rl4j.network.NeuralNet;
 | 
			
		||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.Observation;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.ActionSpace;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.Encodable;
 | 
			
		||||
import org.deeplearning4j.rl4j.support.*;
 | 
			
		||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.activations.Activation;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
@ -186,8 +189,8 @@ public class PolicyTest {
 | 
			
		||||
        QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
 | 
			
		||||
                0, 1.0, 0, 0, 0, 0, true);
 | 
			
		||||
        MockNeuralNet nnMock = new MockNeuralNet();
 | 
			
		||||
        MockRefacPolicy sut = new MockRefacPolicy(nnMock);
 | 
			
		||||
        IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
 | 
			
		||||
        MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
 | 
			
		||||
        MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
@ -197,13 +200,6 @@ public class PolicyTest {
 | 
			
		||||
        assertEquals(1, nnMock.resetCallCount);
 | 
			
		||||
        assertEquals(465.0, totalReward, 0.0001);
 | 
			
		||||
 | 
			
		||||
        // HistoryProcessor
 | 
			
		||||
        assertEquals(16, hp.addCalls.size());
 | 
			
		||||
        assertEquals(31, hp.recordCalls.size());
 | 
			
		||||
        for(int i=0; i <= 30; ++i) {
 | 
			
		||||
            assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // MDP
 | 
			
		||||
        assertEquals(1, mdp.resetCount);
 | 
			
		||||
        assertEquals(30, mdp.actions.size());
 | 
			
		||||
@ -219,10 +215,15 @@ public class PolicyTest {
 | 
			
		||||
    public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
 | 
			
		||||
 | 
			
		||||
        private NeuralNet neuralNet;
 | 
			
		||||
        private final int[] shape;
 | 
			
		||||
        private final int skipFrame;
 | 
			
		||||
        private final int historyLength;
 | 
			
		||||
 | 
			
		||||
        public MockRefacPolicy(NeuralNet neuralNet) {
 | 
			
		||||
 | 
			
		||||
        public MockRefacPolicy(NeuralNet neuralNet, int[] shape, int skipFrame, int historyLength) {
 | 
			
		||||
            this.neuralNet = neuralNet;
 | 
			
		||||
            this.shape = shape;
 | 
			
		||||
            this.skipFrame = skipFrame;
 | 
			
		||||
            this.historyLength = historyLength;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
@ -239,5 +240,11 @@ public class PolicyTest {
 | 
			
		||||
        public Integer nextAction(INDArray input) {
 | 
			
		||||
            return (int)input.getDouble(0);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
 | 
			
		||||
            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
 | 
			
		||||
            return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,12 @@ package org.deeplearning4j.rl4j.support;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.gym.StepReply;
 | 
			
		||||
import org.deeplearning4j.rl4j.mdp.MDP;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
 | 
			
		||||
import org.nd4j.linalg.api.rng.Random;
 | 
			
		||||
@ -77,4 +83,16 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
 | 
			
		||||
    public MDP newInstance() {
 | 
			
		||||
        return null;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
 | 
			
		||||
        return TransformProcess.builder()
 | 
			
		||||
                .filter(new UniformSkippingFilter(skipFrame))
 | 
			
		||||
                .transform("data", new EncodableToINDArrayTransform(shape))
 | 
			
		||||
                .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
 | 
			
		||||
                .transform("data", HistoryMergeTransform.builder()
 | 
			
		||||
                        .elementStore(new CircularFifoStore(historyLength))
 | 
			
		||||
                        .build())
 | 
			
		||||
                .build("data");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user