RL4J: Add TransformProcess, part 1 (#8711)
* Added TransformProcess, part 1 Signed-off-by: unknown <aboulang2002@yahoo.com> * Renamed TemporalMergeTransform to HistoryMergeTransform Signed-off-by: unknown <aboulang2002@yahoo.com> * changed INDArrayHelper to use Nd4j.expandDims Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Adjusted copyrights Signed-off-by: unknown <aboulang2002@yahoo.com>
This commit is contained in:
		
							parent
							
								
									e4ddf109c3
								
							
						
					
					
						commit
						58aa5a3a9b
					
				@ -26,10 +26,7 @@ import org.datavec.api.transform.schema.Schema;
 | 
			
		||||
 *
 | 
			
		||||
 * @author Adam Gibson
 | 
			
		||||
 */
 | 
			
		||||
public interface ColumnOp {
 | 
			
		||||
    /** Get the output schema for this transformation, given an input schema */
 | 
			
		||||
    Schema transform(Schema inputSchema);
 | 
			
		||||
 | 
			
		||||
public interface ColumnOp extends Operation<Schema, Schema> {
 | 
			
		||||
 | 
			
		||||
    /** Set the input schema.
 | 
			
		||||
     */
 | 
			
		||||
 | 
			
		||||
@ -1,28 +1,20 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games).
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor {
 | 
			
		||||
    public abstract void reset();
 | 
			
		||||
}
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.datavec.api.transform;
 | 
			
		||||
 | 
			
		||||
public interface Operation<TIn, TOut> {
 | 
			
		||||
    TOut transform(TIn input);
 | 
			
		||||
}
 | 
			
		||||
@ -16,6 +16,7 @@
 | 
			
		||||
 | 
			
		||||
package org.datavec.image.transform;
 | 
			
		||||
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.datavec.image.data.ImageWritable;
 | 
			
		||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
 | 
			
		||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
 | 
			
		||||
@ -29,15 +30,7 @@ import java.util.Random;
 | 
			
		||||
 */
 | 
			
		||||
@JsonInclude(JsonInclude.Include.NON_NULL)
 | 
			
		||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 | 
			
		||||
public interface ImageTransform {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Takes an image and returns a transformed image.
 | 
			
		||||
     *
 | 
			
		||||
     * @param image to transform, null == end of stream
 | 
			
		||||
     * @return      transformed image
 | 
			
		||||
     */
 | 
			
		||||
    ImageWritable transform(ImageWritable image);
 | 
			
		||||
public interface ImageTransform extends Operation<ImageWritable, ImageWritable>  {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Takes an image and returns a transformed image.
 | 
			
		||||
 | 
			
		||||
@ -102,6 +102,13 @@
 | 
			
		||||
            <artifactId>gson</artifactId>
 | 
			
		||||
            <version>${gson.version}</version>
 | 
			
		||||
        </dependency>
 | 
			
		||||
 | 
			
		||||
        <dependency>
 | 
			
		||||
            <groupId>org.datavec</groupId>
 | 
			
		||||
            <artifactId>datavec-api</artifactId>
 | 
			
		||||
            <version>${datavec.version}</version>
 | 
			
		||||
        </dependency>
 | 
			
		||||
 | 
			
		||||
    </dependencies>
 | 
			
		||||
 | 
			
		||||
    <profiles>
 | 
			
		||||
 | 
			
		||||
@ -1,30 +1,39 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray
 | 
			
		||||
 * returned by the ObservationPool is packaged into the single INDArray that will be set
 | 
			
		||||
 * in the DataSet of PoolingDataSetPreProcessor.preProcess
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public interface PoolContentAssembler {
 | 
			
		||||
    INDArray assemble(INDArray[] poolContent);
 | 
			
		||||
}
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.helper;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * INDArray helper methods used by RL4J
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class INDArrayHelper {
 | 
			
		||||
    /**
 | 
			
		||||
     * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray.
 | 
			
		||||
     * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape.
 | 
			
		||||
     *
 | 
			
		||||
     * @param source A INDArray
 | 
			
		||||
     * @return The source INDArray with the correct shape
 | 
			
		||||
     */
 | 
			
		||||
    public static INDArray forceCorrectShape(INDArray source) {
 | 
			
		||||
        return source.shape()[0] == 1
 | 
			
		||||
                ? source
 | 
			
		||||
                : Nd4j.expandDims(source, 0);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -90,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
 | 
			
		||||
            accuReward += stepReply.getReward() * getConf().getRewardFactor();
 | 
			
		||||
 | 
			
		||||
            //if it's not a skipped frame, you can do a step of training
 | 
			
		||||
            if (!obs.isSkipped() || stepReply.isDone()) {
 | 
			
		||||
            if (!obs.isSkipped()) {
 | 
			
		||||
 | 
			
		||||
                INDArray[] output = current.outputAll(obs.getData());
 | 
			
		||||
                rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
 | 
			
		||||
@ -99,7 +99,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            obs = stepReply.getObservation();
 | 
			
		||||
 | 
			
		||||
            reward += stepReply.getReward();
 | 
			
		||||
 | 
			
		||||
            incrementStep();
 | 
			
		||||
 | 
			
		||||
@ -158,7 +158,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
 | 
			
		||||
        accuReward += stepReply.getReward() * configuration.getRewardFactor();
 | 
			
		||||
 | 
			
		||||
        //if it's not a skipped frame, you can do a step of training
 | 
			
		||||
        if (!obs.isSkipped() || stepReply.isDone()) {
 | 
			
		||||
        if (!obs.isSkipped()) {
 | 
			
		||||
 | 
			
		||||
            // Add experience
 | 
			
		||||
            if(pendingTransition != null) {
 | 
			
		||||
 | 
			
		||||
@ -1,130 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ChannelStackPoolContentAssembler;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.CircularFifoObservationPool;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSet;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The PoolingDataSetPreProcessor will accumulate features from incoming DataSets and will assemble its content
 | 
			
		||||
 * into a DataSet containing a single example.
 | 
			
		||||
 *
 | 
			
		||||
 * There are two special cases:
 | 
			
		||||
 *    1) preProcess will return without doing anything if the input DataSet is empty
 | 
			
		||||
 *    2) When the pool has not yet filled, the data from the incoming DataSet is stored in the pool but the DataSet is emptied
 | 
			
		||||
 *       on exit.
 | 
			
		||||
 * <br>
 | 
			
		||||
 * The PoolingDataSetPreProcessor requires two sub components: <br>
 | 
			
		||||
 *    1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...)
 | 
			
		||||
 *       The default is a Circular FIFO.
 | 
			
		||||
 *    2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...)
 | 
			
		||||
 *       The default is stacking along the dimension 0.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor {
 | 
			
		||||
    private final ObservationPool observationPool;
 | 
			
		||||
    private final PoolContentAssembler poolContentAssembler;
 | 
			
		||||
 | 
			
		||||
    protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder)
 | 
			
		||||
    {
 | 
			
		||||
        observationPool = builder.observationPool;
 | 
			
		||||
        poolContentAssembler = builder.poolContentAssembler;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the
 | 
			
		||||
     * Policy/Learning class and other DataSetPreProcessors
 | 
			
		||||
     *
 | 
			
		||||
     * @param dataSet
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public void preProcess(DataSet dataSet) {
 | 
			
		||||
        Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
 | 
			
		||||
 | 
			
		||||
        if(dataSet.isEmpty()) {
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported");
 | 
			
		||||
 | 
			
		||||
        // store a duplicate in the pool
 | 
			
		||||
        observationPool.add(dataSet.getFeatures().slice(0, 0).dup());
 | 
			
		||||
        if(!observationPool.isAtFullCapacity()) {
 | 
			
		||||
            dataSet.setFeatures(null);
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        INDArray result = poolContentAssembler.assemble(observationPool.get());
 | 
			
		||||
 | 
			
		||||
        // return a DataSet containing only 1 example (the result)
 | 
			
		||||
        long[] resultShape = result.shape();
 | 
			
		||||
        long[] newShape = new long[resultShape.length + 1];
 | 
			
		||||
        newShape[0] = 1;
 | 
			
		||||
        System.arraycopy(resultShape, 0, newShape, 1, resultShape.length);
 | 
			
		||||
 | 
			
		||||
        dataSet.setFeatures(result.reshape(newShape));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static PoolingDataSetPreProcessor.Builder builder() {
 | 
			
		||||
        return new PoolingDataSetPreProcessor.Builder();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        observationPool.reset();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class Builder {
 | 
			
		||||
        private ObservationPool observationPool;
 | 
			
		||||
        private PoolContentAssembler poolContentAssembler;
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Default is CircularFifoObservationPool
 | 
			
		||||
         */
 | 
			
		||||
        public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) {
 | 
			
		||||
            this.observationPool = observationPool;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Default is ChannelStackPoolContentAssembler
 | 
			
		||||
         */
 | 
			
		||||
        public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) {
 | 
			
		||||
            this.poolContentAssembler = poolContentAssembler;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public PoolingDataSetPreProcessor build() {
 | 
			
		||||
            if(observationPool == null) {
 | 
			
		||||
                observationPool = new CircularFifoObservationPool();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if(poolContentAssembler == null) {
 | 
			
		||||
                poolContentAssembler = new ChannelStackPoolContentAssembler();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            return new PoolingDataSetPreProcessor(this);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,62 +0,0 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor;
 | 
			
		||||
 | 
			
		||||
import lombok.Builder;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.dataset.api.DataSet;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty
 | 
			
		||||
 * the input DataSet when skipping.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor {
 | 
			
		||||
 | 
			
		||||
    private final int skipFrame;
 | 
			
		||||
 | 
			
		||||
    private int currentIdx = 0;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations.
 | 
			
		||||
     */
 | 
			
		||||
    @Builder
 | 
			
		||||
    public SkippingDataSetPreProcessor(int skipFrame) {
 | 
			
		||||
        Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame);
 | 
			
		||||
        this.skipFrame = skipFrame;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void preProcess(DataSet dataSet) {
 | 
			
		||||
        Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
 | 
			
		||||
 | 
			
		||||
        if(dataSet.isEmpty()) {
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if(currentIdx++ % skipFrame != 0) {
 | 
			
		||||
            dataSet.setFeatures(null);
 | 
			
		||||
            dataSet.setLabels(null);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        currentIdx = 0;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,35 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform;
 | 
			
		||||
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Used with {@link TransformProcess TransformProcess} to filter-out an observation.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public interface FilterOperation {
 | 
			
		||||
    /**
 | 
			
		||||
     * The logic that determines if the observation should be skipped.
 | 
			
		||||
     *
 | 
			
		||||
     * @param channelsData the name of the channel
 | 
			
		||||
     * @param currentObservationStep The step number if the observation in the current episode.
 | 
			
		||||
     * @param isFinalObservation true if this is the last observation of the episode
 | 
			
		||||
     * @return true if the observation should be skipped
 | 
			
		||||
     */
 | 
			
		||||
    boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation);
 | 
			
		||||
}
 | 
			
		||||
@ -1,32 +1,26 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * ObservationPool is used with the PoolingDataSetPreProcessor. Used to supervise how data from the
 | 
			
		||||
 * PoolingDataSetPreProcessor is stored.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public interface ObservationPool {
 | 
			
		||||
    void add(INDArray observation);
 | 
			
		||||
    INDArray[] get();
 | 
			
		||||
    boolean isAtFullCapacity();
 | 
			
		||||
    void reset();
 | 
			
		||||
}
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The {@link TransformProcess TransformProcess} will call reset() (at the start of an episode) of any step that implement this interface.
 | 
			
		||||
 */
 | 
			
		||||
public interface ResettableOperation {
 | 
			
		||||
    /**
 | 
			
		||||
     * Called by TransformProcess when an episode starts. See {@link TransformProcess#reset() TransformProcess.reset()}
 | 
			
		||||
     */
 | 
			
		||||
    void reset();
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,45 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.filter;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.FilterOperation;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import java.util.Map;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Used with {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}. Will cause the
 | 
			
		||||
 * transform process to skip a fixed number of frames between non skipped ones.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class UniformSkippingFilter implements FilterOperation {
 | 
			
		||||
 | 
			
		||||
    private final int skipFrame;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @param skipFrame Will cause the filter to keep (not skip) 1 frame every skipFrames.
 | 
			
		||||
     */
 | 
			
		||||
    public UniformSkippingFilter(int skipFrame) {
 | 
			
		||||
        Preconditions.checkArgument(skipFrame > 0, "skipFrame should be greater than 0");
 | 
			
		||||
 | 
			
		||||
        this.skipFrame = skipFrame;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
 | 
			
		||||
        return !isFinalObservation && (currentObservationStep % skipFrame != 0);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,41 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
 | 
			
		||||
 | 
			
		||||
import org.bytedeco.javacv.OpenCVFrameConverter;
 | 
			
		||||
import org.bytedeco.opencv.opencv_core.Mat;
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.datavec.image.data.ImageWritable;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.Encodable;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
 | 
			
		||||
 | 
			
		||||
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
 | 
			
		||||
 | 
			
		||||
    private final int[] shape;
 | 
			
		||||
 | 
			
		||||
    public EncodableToINDArrayTransform(int[] shape) {
 | 
			
		||||
        this.shape = shape;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray transform(Encodable encodable) {
 | 
			
		||||
        return Nd4j.create(encodable.toArray()).reshape(shape);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,48 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
 | 
			
		||||
 | 
			
		||||
import org.bytedeco.javacv.OpenCVFrameConverter;
 | 
			
		||||
import org.bytedeco.opencv.opencv_core.Mat;
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.datavec.image.data.ImageWritable;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.Encodable;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
 | 
			
		||||
 | 
			
		||||
public class EncodableToImageWriteableTransform implements Operation<Encodable, ImageWritable> {
 | 
			
		||||
 | 
			
		||||
    private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
 | 
			
		||||
    private final int height;
 | 
			
		||||
    private final int width;
 | 
			
		||||
    private final int colorChannels;
 | 
			
		||||
 | 
			
		||||
    public EncodableToImageWriteableTransform(int height, int width, int colorChannels) {
 | 
			
		||||
        this.height = height;
 | 
			
		||||
        this.width = width;
 | 
			
		||||
        this.colorChannels = colorChannels;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public ImageWritable transform(Encodable encodable) {
 | 
			
		||||
        INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels);
 | 
			
		||||
        Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
 | 
			
		||||
        return new ImageWritable(converter.convert(mat));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,37 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
 | 
			
		||||
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.datavec.image.data.ImageWritable;
 | 
			
		||||
import org.datavec.image.loader.NativeImageLoader;
 | 
			
		||||
import org.deeplearning4j.rl4j.space.Encodable;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import java.io.IOException;
 | 
			
		||||
 | 
			
		||||
public class ImageWriteableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
 | 
			
		||||
 | 
			
		||||
    private final int height;
 | 
			
		||||
    private final int width;
 | 
			
		||||
    private final NativeImageLoader loader;
 | 
			
		||||
 | 
			
		||||
    public ImageWriteableToINDArrayTransform(int height, int width) {
 | 
			
		||||
        this.height = height;
 | 
			
		||||
        this.width = width;
 | 
			
		||||
        this.loader = new NativeImageLoader(height, width);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray transform(ImageWritable imageWritable) {
 | 
			
		||||
        INDArray out = null;
 | 
			
		||||
        try {
 | 
			
		||||
            out = loader.asMatrix(imageWritable);
 | 
			
		||||
        } catch (IOException e) {
 | 
			
		||||
            e.printStackTrace();
 | 
			
		||||
        }
 | 
			
		||||
        out = out.reshape(1, height, width);
 | 
			
		||||
        INDArray compressed = out.castTo(DataType.UINT8);
 | 
			
		||||
        return compressed;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,147 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation;
 | 
			
		||||
 | 
			
		||||
import org.datavec.api.transform.Operation;
 | 
			
		||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.ResettableOperation;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryStackAssembler;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * The HistoryMergeTransform will accumulate features from incoming INDArrays and will assemble its content
 | 
			
		||||
 * into a new INDArray containing a single example.
 | 
			
		||||
 *
 | 
			
		||||
 * This is used in scenarios where motion in an important element.
 | 
			
		||||
 *
 | 
			
		||||
 * There is a special case:
 | 
			
		||||
 *    * When the store is not full (not ready), the data from the incoming INDArray is stored but null is returned (will be interpreted as a skipped observation)
 | 
			
		||||
 * <br>
 | 
			
		||||
 * The HistoryMergeTransform requires two sub components: <br>
 | 
			
		||||
 *    1) The {@link HistoryMergeElementStore HistoryMergeElementStore} that supervises what and how input INDArrays are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...)
 | 
			
		||||
 *       The default is a Circular FIFO.
 | 
			
		||||
 *    2) The {@link HistoryMergeAssembler HistoryMergeAssembler} that will assemble the store content into a resulting single INDArray. (ex.: stacked along a dimension, squashed into a single observation, etc...)
 | 
			
		||||
 *       The default is stacking along the dimension 0.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class HistoryMergeTransform implements Operation<INDArray, INDArray>, ResettableOperation {
 | 
			
		||||
 | 
			
		||||
    private final HistoryMergeElementStore historyMergeElementStore;
 | 
			
		||||
    private final HistoryMergeAssembler historyMergeAssembler;
 | 
			
		||||
    private final boolean shouldStoreCopy;
 | 
			
		||||
    private final boolean isFirstDimenstionBatch;
 | 
			
		||||
 | 
			
		||||
    private HistoryMergeTransform(Builder builder) {
 | 
			
		||||
        this.historyMergeElementStore = builder.historyMergeElementStore;
 | 
			
		||||
        this.historyMergeAssembler = builder.historyMergeAssembler;
 | 
			
		||||
        this.shouldStoreCopy = builder.shouldStoreCopy;
 | 
			
		||||
        this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray transform(INDArray input) {
 | 
			
		||||
        INDArray element;
 | 
			
		||||
        if(isFirstDimenstionBatch) {
 | 
			
		||||
            element = input.slice(0, 0);
 | 
			
		||||
        }
 | 
			
		||||
        else {
 | 
			
		||||
            element = input;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if(shouldStoreCopy) {
 | 
			
		||||
            element = element.dup();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        historyMergeElementStore.add(element);
 | 
			
		||||
        if(!historyMergeElementStore.isReady()) {
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        INDArray result = historyMergeAssembler.assemble(historyMergeElementStore.get());
 | 
			
		||||
 | 
			
		||||
        return INDArrayHelper.forceCorrectShape(result);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        historyMergeElementStore.reset();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static Builder builder() {
 | 
			
		||||
        return new Builder();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class Builder {
 | 
			
		||||
        private HistoryMergeElementStore historyMergeElementStore;
 | 
			
		||||
        private HistoryMergeAssembler historyMergeAssembler;
 | 
			
		||||
        private boolean shouldStoreCopy = false;
 | 
			
		||||
        private boolean isFirstDimenstionBatch = false;
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Default is {@link CircularFifoStore CircularFifoStore}
 | 
			
		||||
         */
 | 
			
		||||
        public Builder elementStore(HistoryMergeElementStore store) {
 | 
			
		||||
            this.historyMergeElementStore = store;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * Default is {@link HistoryStackAssembler HistoryStackAssembler}
 | 
			
		||||
         */
 | 
			
		||||
        public Builder assembler(HistoryMergeAssembler assembler) {
 | 
			
		||||
            this.historyMergeAssembler = assembler;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * If true, tell the HistoryMergeTransform to store copies of incoming INDArrays.
 | 
			
		||||
         * (To prevent later in-place changes to a stored INDArray from changing what has been stored)
 | 
			
		||||
         *
 | 
			
		||||
         * Default is false
 | 
			
		||||
         */
 | 
			
		||||
        public Builder shouldStoreCopy(boolean shouldStoreCopy) {
 | 
			
		||||
            this.shouldStoreCopy = shouldStoreCopy;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        /**
 | 
			
		||||
         * If true, tell the HistoryMergeTransform that the first dimension of the input INDArray is the batch count.
 | 
			
		||||
         * When this is the case, the HistoryMergeTransform will slice the input like this [batch, height, width] -> [height, width]
 | 
			
		||||
         *
 | 
			
		||||
         * Default is false
 | 
			
		||||
         */
 | 
			
		||||
        public Builder isFirstDimenstionBatch(boolean isFirstDimenstionBatch) {
 | 
			
		||||
            this.isFirstDimenstionBatch = isFirstDimenstionBatch;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public HistoryMergeTransform build() {
 | 
			
		||||
            if(historyMergeElementStore == null) {
 | 
			
		||||
                historyMergeElementStore = new CircularFifoStore();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if(historyMergeAssembler == null) {
 | 
			
		||||
                historyMergeAssembler = new HistoryStackAssembler();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            return new HistoryMergeTransform(this);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -1,95 +1,82 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue
 | 
			
		||||
 * with a fixed size that replaces its oldest element if full.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class CircularFifoObservationPool implements ObservationPool {
 | 
			
		||||
    private static final int DEFAULT_POOL_SIZE = 4;
 | 
			
		||||
 | 
			
		||||
    private final CircularFifoQueue<INDArray> queue;
 | 
			
		||||
 | 
			
		||||
    private CircularFifoObservationPool(Builder builder) {
 | 
			
		||||
        queue = new CircularFifoQueue<>(builder.poolSize);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public CircularFifoObservationPool()
 | 
			
		||||
    {
 | 
			
		||||
        this(DEFAULT_POOL_SIZE);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public CircularFifoObservationPool(int poolSize)
 | 
			
		||||
    {
 | 
			
		||||
        Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize);
 | 
			
		||||
        queue = new CircularFifoQueue<>(poolSize);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one.
 | 
			
		||||
     * @param elem
 | 
			
		||||
     */
 | 
			
		||||
    public void add(INDArray elem) {
 | 
			
		||||
        queue.add(elem);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @return The content of the pool, returned in order from oldest to newest.
 | 
			
		||||
     */
 | 
			
		||||
    public INDArray[] get() {
 | 
			
		||||
        int size = queue.size();
 | 
			
		||||
        INDArray[] array = new INDArray[size];
 | 
			
		||||
        for (int i = 0; i < size; ++i) {
 | 
			
		||||
            array[i] = queue.get(i).castTo(Nd4j.dataType());
 | 
			
		||||
        }
 | 
			
		||||
        return array;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public boolean isAtFullCapacity() {
 | 
			
		||||
        return queue.isAtFullCapacity();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        queue.clear();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static Builder builder() {
 | 
			
		||||
        return new Builder();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class Builder {
 | 
			
		||||
        private int poolSize = DEFAULT_POOL_SIZE;
 | 
			
		||||
 | 
			
		||||
        public Builder poolSize(int poolSize) {
 | 
			
		||||
            this.poolSize = poolSize;
 | 
			
		||||
            return this;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        public CircularFifoObservationPool build() {
 | 
			
		||||
            return new CircularFifoObservationPool(this);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 | 
			
		||||
import org.nd4j.base.Preconditions;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * CircularFifoStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This store is a first-in first-out queue
 | 
			
		||||
 * with a fixed size that replaces its oldest element if full.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class CircularFifoStore implements HistoryMergeElementStore {
 | 
			
		||||
    private static final int DEFAULT_STORE_SIZE = 4;
 | 
			
		||||
 | 
			
		||||
    private final CircularFifoQueue<INDArray> queue;
 | 
			
		||||
 | 
			
		||||
    public CircularFifoStore() {
 | 
			
		||||
        this(DEFAULT_STORE_SIZE);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public CircularFifoStore(int size) {
 | 
			
		||||
        Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
 | 
			
		||||
        queue = new CircularFifoQueue<>(size);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Add an element to the store, if this addition would make the store to overflow, the new element replaces the oldest.
 | 
			
		||||
     * @param elem
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public void add(INDArray elem) {
 | 
			
		||||
        queue.add(elem);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @return The content of the store, returned in order from oldest to newest.
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray[] get() {
 | 
			
		||||
        int size = queue.size();
 | 
			
		||||
        INDArray[] array = new INDArray[size];
 | 
			
		||||
        for (int i = 0; i < size; ++i) {
 | 
			
		||||
            array[i] = queue.get(i).castTo(Nd4j.dataType());
 | 
			
		||||
        }
 | 
			
		||||
        return array;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * The CircularFifoStore needs to be completely filled before being ready.
 | 
			
		||||
     * @return false when the number of elements in the store is less than the store capacity (default is 4)
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public boolean isReady() {
 | 
			
		||||
        return queue.isAtFullCapacity();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Clears the store.
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public void reset() {
 | 
			
		||||
        queue.clear();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,35 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * A HistoryMergeAssembler is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This interface defines how the array of INDArray
 | 
			
		||||
 * given by the {@link HistoryMergeElementStore HistoryMergeElementStore} is packaged into the single INDArray that will be
 | 
			
		||||
 * returned by the HistoryMergeTransform
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public interface HistoryMergeAssembler {
 | 
			
		||||
    /**
 | 
			
		||||
     * Assemble an array of INDArray into a single INArray
 | 
			
		||||
     * @param elements The input INDArray[]
 | 
			
		||||
     * @return the assembled INDArray
 | 
			
		||||
     */
 | 
			
		||||
    INDArray assemble(INDArray[] elements);
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,51 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * HistoryMergeElementStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. Used to supervise how data from the
 | 
			
		||||
 * HistoryMergeTransform is stored.
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public interface HistoryMergeElementStore {
 | 
			
		||||
    /**
 | 
			
		||||
     * Add an element into the store
 | 
			
		||||
     * @param observation
 | 
			
		||||
     */
 | 
			
		||||
    void add(INDArray observation);
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Get the content of the store
 | 
			
		||||
     * @return the content of the store
 | 
			
		||||
     */
 | 
			
		||||
    INDArray[] get();
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Used to tell the HistoryMergeTransform that the store is ready. The HistoryMergeTransform will tell the {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}
 | 
			
		||||
     * to skip the observation is the store is not ready.
 | 
			
		||||
     * @return true if the store is ready
 | 
			
		||||
     */
 | 
			
		||||
    boolean isReady();
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Resets the store to an initial state.
 | 
			
		||||
     */
 | 
			
		||||
    void reset();
 | 
			
		||||
}
 | 
			
		||||
@ -1,53 +1,52 @@
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2015-2019 Skymind, Inc.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will
 | 
			
		||||
 * stack along the dimension 0. For example if the pool elements are of shape [ Height, Width ]
 | 
			
		||||
 * the output will be of shape [ Stacked, Height, Width ]
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class ChannelStackPoolContentAssembler implements PoolContentAssembler {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0.
 | 
			
		||||
     *
 | 
			
		||||
     * @param poolContent Array of INDArray
 | 
			
		||||
     * @return A new INDArray with 1 more dimension than the input elements
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray assemble(INDArray[] poolContent)
 | 
			
		||||
    {
 | 
			
		||||
        // build the new shape
 | 
			
		||||
        long[] elementShape = poolContent[0].shape();
 | 
			
		||||
        long[] newShape = new long[elementShape.length + 1];
 | 
			
		||||
        newShape[0] = poolContent.length;
 | 
			
		||||
        System.arraycopy(elementShape, 0, newShape, 1, elementShape.length);
 | 
			
		||||
 | 
			
		||||
        // put pool elements in result
 | 
			
		||||
        INDArray result = Nd4j.create(newShape);
 | 
			
		||||
        for(int i = 0; i < poolContent.length; ++i) {
 | 
			
		||||
            result.putRow(i, poolContent[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
/*******************************************************************************
 | 
			
		||||
 * Copyright (c) 2020 Konduit K.K.
 | 
			
		||||
 *
 | 
			
		||||
 * This program and the accompanying materials are made available under the
 | 
			
		||||
 * terms of the Apache License, Version 2.0 which is available at
 | 
			
		||||
 * https://www.apache.org/licenses/LICENSE-2.0.
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
			
		||||
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
			
		||||
 * License for the specific language governing permissions and limitations
 | 
			
		||||
 * under the License.
 | 
			
		||||
 *
 | 
			
		||||
 * SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 ******************************************************************************/
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * HistoryStackAssembler is used with the HistoryMergeTransform. This assembler will
 | 
			
		||||
 * stack along the dimension 0. For example if the store elements are of shape [ Height, Width ]
 | 
			
		||||
 * the output will be of shape [ Stacked, Height, Width ]
 | 
			
		||||
 *
 | 
			
		||||
 * @author Alexandre Boulanger
 | 
			
		||||
 */
 | 
			
		||||
public class HistoryStackAssembler implements HistoryMergeAssembler {
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Will return a new INDArray with one more dimension and with elements stacked along dimension 0.
 | 
			
		||||
     *
 | 
			
		||||
     * @param elements Array of INDArray
 | 
			
		||||
     * @return A new INDArray with 1 more dimension than the input elements
 | 
			
		||||
     */
 | 
			
		||||
    @Override
 | 
			
		||||
    public INDArray assemble(INDArray[] elements) {
 | 
			
		||||
        // build the new shape
 | 
			
		||||
        long[] elementShape = elements[0].shape();
 | 
			
		||||
        long[] newShape = new long[elementShape.length + 1];
 | 
			
		||||
        newShape[0] = elements.length;
 | 
			
		||||
        System.arraycopy(elementShape, 0, newShape, 1, elementShape.length);
 | 
			
		||||
 | 
			
		||||
        // stack the elements in result on the dimension 0
 | 
			
		||||
        INDArray result = Nd4j.create(newShape);
 | 
			
		||||
        for(int i = 0; i < elements.length; ++i) {
 | 
			
		||||
            result.putRow(i, elements[i]);
 | 
			
		||||
        }
 | 
			
		||||
        return result;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -89,7 +89,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
 | 
			
		||||
        getNeuralNet().reset();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
 | 
			
		||||
    protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
 | 
			
		||||
        epochStepCounter.setCurrentEpochStep(0);
 | 
			
		||||
 | 
			
		||||
        double reward = 0;
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,38 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.helper;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class INDArrayHelperTest {
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_inputHasIncorrectShape_expect_outputWithCorrectShape() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0});
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray output = INDArrayHelper.forceCorrectShape(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(2, output.shape().length);
 | 
			
		||||
        assertEquals(1, output.shape()[0]);
 | 
			
		||||
        assertEquals(3, output.shape()[1]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_inputHasCorrectShape_expect_outputWithSameShape() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}).reshape(1, 3);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray output = INDArrayHelper.forceCorrectShape(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(2, output.shape().length);
 | 
			
		||||
        assertEquals(1, output.shape()[0]);
 | 
			
		||||
        assertEquals(3, output.shape()[1]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,164 +0,0 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
 | 
			
		||||
import org.junit.Assert;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.dataset.DataSet;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static junit.framework.TestCase.assertTrue;
 | 
			
		||||
import static org.junit.Assert.assertEquals;
 | 
			
		||||
 | 
			
		||||
public class PoolingDataSetPreProcessorTest {
 | 
			
		||||
 | 
			
		||||
    @Test(expected = NullPointerException.class)
 | 
			
		||||
    public void when_dataSetIsNull_expect_NullPointerException() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.preProcess(null);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_dataSetIsEmpty_expect_EmptyDataSet() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
 | 
			
		||||
        DataSet ds = new DataSet(null, null);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.preProcess(ds);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        Assert.assertTrue(ds.isEmpty());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
 | 
			
		||||
        DataSet[] observations = new DataSet[5];
 | 
			
		||||
        INDArray[] inputs = new INDArray[5];
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        for(int i = 0; i < 5; ++i) {
 | 
			
		||||
            inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 });
 | 
			
		||||
            DataSet input = new DataSet(inputs[i], null);
 | 
			
		||||
            sut.preProcess(input);
 | 
			
		||||
            observations[i] = input;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(observations[0].isEmpty());
 | 
			
		||||
        assertTrue(observations[1].isEmpty());
 | 
			
		||||
        assertTrue(observations[2].isEmpty());
 | 
			
		||||
 | 
			
		||||
        for(int i = 0; i < 4; ++i) {
 | 
			
		||||
            assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for(int i = 0; i < 4; ++i) {
 | 
			
		||||
            assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
 | 
			
		||||
            assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray input = Nd4j.rand(1, 1);
 | 
			
		||||
        TestObservationPool pool = new TestObservationPool();
 | 
			
		||||
        TestPoolContentAssembler assembler = new TestPoolContentAssembler();
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
 | 
			
		||||
                .observablePool(pool)
 | 
			
		||||
                .poolContentAssembler(assembler)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.preProcess(new DataSet(input, null));
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(pool.isAtFullCapacityCalled);
 | 
			
		||||
        assertTrue(pool.isGetCalled);
 | 
			
		||||
        assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0);
 | 
			
		||||
        assertTrue(assembler.assembleIsCalled);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_pastInputChanges_expect_outputNotChanged() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray input = Nd4j.zeros(1, 1);
 | 
			
		||||
        TestObservationPool pool = new TestObservationPool();
 | 
			
		||||
        TestPoolContentAssembler assembler = new TestPoolContentAssembler();
 | 
			
		||||
        PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
 | 
			
		||||
                .observablePool(pool)
 | 
			
		||||
                .poolContentAssembler(assembler)
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.preProcess(new DataSet(input, null));
 | 
			
		||||
        input.putScalar(0, 0, 1.0);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(0.0, pool.observation.getDouble(0), 0.0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class TestObservationPool implements ObservationPool {
 | 
			
		||||
 | 
			
		||||
        public INDArray observation;
 | 
			
		||||
        public boolean isGetCalled;
 | 
			
		||||
        public boolean isAtFullCapacityCalled;
 | 
			
		||||
        private boolean isResetCalled;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void add(INDArray observation) {
 | 
			
		||||
            this.observation = observation;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public INDArray[] get() {
 | 
			
		||||
            isGetCalled = true;
 | 
			
		||||
            return new INDArray[0];
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean isAtFullCapacity() {
 | 
			
		||||
            isAtFullCapacityCalled = true;
 | 
			
		||||
            return true;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void reset() {
 | 
			
		||||
            isResetCalled = true;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private static class TestPoolContentAssembler implements PoolContentAssembler {
 | 
			
		||||
 | 
			
		||||
        public boolean assembleIsCalled;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public INDArray assemble(INDArray[] poolContent) {
 | 
			
		||||
            assembleIsCalled = true;
 | 
			
		||||
            return Nd4j.create(1, 1);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -1,70 +0,0 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.dataset.DataSet;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.assertFalse;
 | 
			
		||||
import static org.junit.Assert.assertTrue;
 | 
			
		||||
 | 
			
		||||
public class SkippingDataSetPreProcessorTest {
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() {
 | 
			
		||||
        SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() {
 | 
			
		||||
        SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
 | 
			
		||||
                .skipFrame(0)
 | 
			
		||||
                .build();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_skipFrameIs3_expect_Skip2OutOf3() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
 | 
			
		||||
                .skipFrame(3)
 | 
			
		||||
                .build();
 | 
			
		||||
        DataSet[] results = new DataSet[4];
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        for(int i = 0; i < 4; ++i) {
 | 
			
		||||
            results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
 | 
			
		||||
            sut.preProcess(results[i]);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(results[0].isEmpty());
 | 
			
		||||
        assertTrue(results[1].isEmpty());
 | 
			
		||||
        assertTrue(results[2].isEmpty());
 | 
			
		||||
        assertFalse(results[3].isEmpty());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_resetIsCalled_expect_skippingIsReset() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
 | 
			
		||||
                .skipFrame(3)
 | 
			
		||||
                .build();
 | 
			
		||||
        DataSet[] results = new DataSet[4];
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
 | 
			
		||||
        results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
 | 
			
		||||
        results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
 | 
			
		||||
        results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
 | 
			
		||||
 | 
			
		||||
        sut.preProcess(results[0]);
 | 
			
		||||
        sut.preProcess(results[1]);
 | 
			
		||||
        sut.reset();
 | 
			
		||||
        sut.preProcess(results[2]);
 | 
			
		||||
        sut.preProcess(results[3]);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(results[0].isEmpty());
 | 
			
		||||
        assertTrue(results[1].isEmpty());
 | 
			
		||||
        assertFalse(results[2].isEmpty());
 | 
			
		||||
        assertTrue(results[3].isEmpty());
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -1,41 +0,0 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.assertEquals;
 | 
			
		||||
 | 
			
		||||
public class ChannelStackPoolContentAssemblerTest {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_assemble_expect_poolContentStackedOnChannel() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler();
 | 
			
		||||
        INDArray[] poolContent = new INDArray[] {
 | 
			
		||||
            Nd4j.rand(2, 2),
 | 
			
		||||
            Nd4j.rand(2, 2),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.assemble(poolContent);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(3, result.shape().length);
 | 
			
		||||
        assertEquals(2, result.shape()[0]);
 | 
			
		||||
        assertEquals(2, result.shape()[1]);
 | 
			
		||||
        assertEquals(2, result.shape()[2]);
 | 
			
		||||
 | 
			
		||||
        assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001);
 | 
			
		||||
 | 
			
		||||
        assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001);
 | 
			
		||||
        assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001);
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,100 +0,0 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.assertEquals;
 | 
			
		||||
import static org.junit.Assert.assertFalse;
 | 
			
		||||
import static org.junit.Assert.assertTrue;
 | 
			
		||||
 | 
			
		||||
public class CircularFifoObservationPoolTest {
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() {
 | 
			
		||||
        CircularFifoObservationPool sut = new CircularFifoObservationPool(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_poolIsEmpty_expect_NotReady() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        CircularFifoObservationPool sut = new CircularFifoObservationPool();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        boolean isReady = sut.isAtFullCapacity();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(isReady);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_notEnoughElementsInPool_expect_notReady() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        CircularFifoObservationPool sut = new CircularFifoObservationPool();
 | 
			
		||||
        sut.add(Nd4j.create(new double[] { 123.0 }));
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        boolean isReady = sut.isAtFullCapacity();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(isReady);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_enoughElementsInPool_expect_ready() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        CircularFifoObservationPool sut = CircularFifoObservationPool.builder()
 | 
			
		||||
                .poolSize(2)
 | 
			
		||||
                .build();
 | 
			
		||||
        sut.add(Nd4j.createFromArray(123.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(123.0));
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        boolean isReady = sut.isAtFullCapacity();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertTrue(isReady);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_addMoreThanSize_expect_getReturnOnlyLastElements() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
 | 
			
		||||
        sut.add(Nd4j.createFromArray(0.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(1.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(2.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(3.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(4.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(5.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(6.0));
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray[] result = sut.get();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(3.0, result[0].getDouble(0), 0.0);
 | 
			
		||||
        assertEquals(4.0, result[1].getDouble(0), 0.0);
 | 
			
		||||
        assertEquals(5.0, result[2].getDouble(0), 0.0);
 | 
			
		||||
        assertEquals(6.0, result[3].getDouble(0), 0.0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_resetIsCalled_expect_poolContentFlushed() {
 | 
			
		||||
        // Assemble
 | 
			
		||||
        CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
 | 
			
		||||
        sut.add(Nd4j.createFromArray(0.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(1.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(2.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(3.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(4.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(5.0));
 | 
			
		||||
        sut.add(Nd4j.createFromArray(6.0));
 | 
			
		||||
        sut.reset();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray[] result = sut.get();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(0, result.length);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,166 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation;
 | 
			
		||||
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler;
 | 
			
		||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore;
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class HistoryMergeTransformTest {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_firstDimensionIsNotBatch_expect_observationAddedAsIs() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(false);
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .isFirstDimenstionBatch(false)
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(1, store.addedObservation.shape().length);
 | 
			
		||||
        assertEquals(3, store.addedObservation.shape()[0]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_firstDimensionIsBatch_expect_observationAddedAsSliced() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(false);
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .isFirstDimenstionBatch(true)
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(1, store.addedObservation.shape().length);
 | 
			
		||||
        assertEquals(3, store.addedObservation.shape()[0]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_notReady_expect_resultIsNull() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(false);
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .isFirstDimenstionBatch(true)
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertNull(result);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_notShouldStoreCopy_expect_sameIsStored() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(false);
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .shouldStoreCopy(false)
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertSame(input, store.addedObservation);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_shouldStoreCopy_expect_copyIsStored() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(true);
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .shouldStoreCopy(true)
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertNotSame(input, store.addedObservation);
 | 
			
		||||
        assertEquals(1, store.addedObservation.shape().length);
 | 
			
		||||
        assertEquals(3, store.addedObservation.shape()[0]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_transformCalled_expect_storeContentAssembledAndOutputHasCorrectShape() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        MockStore store = new MockStore(true);
 | 
			
		||||
        MockAssemble assemble = new MockAssemble();
 | 
			
		||||
        HistoryMergeTransform sut = HistoryMergeTransform.builder()
 | 
			
		||||
                .elementStore(store)
 | 
			
		||||
                .assembler(assemble)
 | 
			
		||||
                .build();
 | 
			
		||||
        INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.transform(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(1, assemble.assembleElements.length);
 | 
			
		||||
        assertSame(store.addedObservation, assemble.assembleElements[0]);
 | 
			
		||||
 | 
			
		||||
        assertEquals(2, result.shape().length);
 | 
			
		||||
        assertEquals(1, result.shape()[0]);
 | 
			
		||||
        assertEquals(3, result.shape()[1]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class MockStore implements HistoryMergeElementStore {
 | 
			
		||||
 | 
			
		||||
        private final boolean isReady;
 | 
			
		||||
        private INDArray addedObservation;
 | 
			
		||||
 | 
			
		||||
        public MockStore(boolean isReady) {
 | 
			
		||||
 | 
			
		||||
            this.isReady = isReady;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void add(INDArray observation) {
 | 
			
		||||
            addedObservation = observation;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public INDArray[] get() {
 | 
			
		||||
            return new INDArray[] { addedObservation };
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public boolean isReady() {
 | 
			
		||||
            return isReady;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public void reset() {
 | 
			
		||||
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public static class MockAssemble implements HistoryMergeAssembler {
 | 
			
		||||
 | 
			
		||||
        private INDArray[] assembleElements;
 | 
			
		||||
 | 
			
		||||
        @Override
 | 
			
		||||
        public INDArray assemble(INDArray[] elements) {
 | 
			
		||||
            assembleElements = elements;
 | 
			
		||||
            return elements[0];
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,77 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class CircularFifoStoreTest {
 | 
			
		||||
 | 
			
		||||
    @Test(expected = IllegalArgumentException.class)
 | 
			
		||||
    public void when_fifoSizeIsLessThan1_expect_exception() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        CircularFifoStore sut = new CircularFifoStore(0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_adding2elementsWithSize2_expect_notReadyAfter1stReadyAfter2nd() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        CircularFifoStore sut = new CircularFifoStore(2);
 | 
			
		||||
        INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
        INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.add(firstElement);
 | 
			
		||||
        boolean isReadyAfter1st = sut.isReady();
 | 
			
		||||
        sut.add(secondElement);
 | 
			
		||||
        boolean isReadyAfter2nd = sut.isReady();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertFalse(isReadyAfter1st);
 | 
			
		||||
        assertTrue(isReadyAfter2nd);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_adding2elementsWithSize2_expect_getReturnThese2() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        CircularFifoStore sut = new CircularFifoStore(2);
 | 
			
		||||
        INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
        INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.add(firstElement);
 | 
			
		||||
        sut.add(secondElement);
 | 
			
		||||
        INDArray[] results = sut.get();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(2, results.length);
 | 
			
		||||
 | 
			
		||||
        assertEquals(1.0, results[0].getDouble(0), 0.00001);
 | 
			
		||||
        assertEquals(2.0, results[0].getDouble(1), 0.00001);
 | 
			
		||||
        assertEquals(3.0, results[0].getDouble(2), 0.00001);
 | 
			
		||||
 | 
			
		||||
        assertEquals(10.0, results[1].getDouble(0), 0.00001);
 | 
			
		||||
        assertEquals(20.0, results[1].getDouble(1), 0.00001);
 | 
			
		||||
        assertEquals(30.0, results[1].getDouble(2), 0.00001);
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_adding2elementsThenCallingReset_expect_getReturnEmpty() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        CircularFifoStore sut = new CircularFifoStore(2);
 | 
			
		||||
        INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 | 
			
		||||
        INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        sut.add(firstElement);
 | 
			
		||||
        sut.add(secondElement);
 | 
			
		||||
        sut.reset();
 | 
			
		||||
        INDArray[] results = sut.get();
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(0, results.length);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,37 @@
 | 
			
		||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
 | 
			
		||||
 | 
			
		||||
import org.junit.Test;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
 | 
			
		||||
import static org.junit.Assert.*;
 | 
			
		||||
 | 
			
		||||
public class HistoryStackAssemblerTest {
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void when_assembling2INDArrays_expect_stackedAsResult() {
 | 
			
		||||
        // Arrange
 | 
			
		||||
        INDArray[] input = new INDArray[] {
 | 
			
		||||
                Nd4j.create(new double[] { 1.0, 2.0, 3.0 }),
 | 
			
		||||
                Nd4j.create(new double[] { 10.0, 20.0, 30.0 }),
 | 
			
		||||
        };
 | 
			
		||||
        HistoryStackAssembler sut = new HistoryStackAssembler();
 | 
			
		||||
 | 
			
		||||
        // Act
 | 
			
		||||
        INDArray result = sut.assemble(input);
 | 
			
		||||
 | 
			
		||||
        // Assert
 | 
			
		||||
        assertEquals(2, result.shape().length);
 | 
			
		||||
        assertEquals(2, result.shape()[0]);
 | 
			
		||||
        assertEquals(3, result.shape()[1]);
 | 
			
		||||
 | 
			
		||||
        assertEquals(1.0, result.getDouble(0, 0), 0.00001);
 | 
			
		||||
        assertEquals(2.0, result.getDouble(0, 1), 0.00001);
 | 
			
		||||
        assertEquals(3.0, result.getDouble(0, 2), 0.00001);
 | 
			
		||||
 | 
			
		||||
        assertEquals(10.0, result.getDouble(1, 0), 0.00001);
 | 
			
		||||
        assertEquals(20.0, result.getDouble(1, 1), 0.00001);
 | 
			
		||||
        assertEquals(30.0, result.getDouble(1, 2), 0.00001);
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user