Merge remote-tracking branch 'eclipse/master'
Signed-off-by: Samuel Audet <samuel.audet@gmail.com>master
commit
5cd143611e
|
@ -66,3 +66,7 @@ doc_sources_*
|
|||
# Python virtual environments
|
||||
venv/
|
||||
venv2/
|
||||
|
||||
# Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
|
||||
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
|
||||
nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
|
||||
|
|
|
@ -26,10 +26,7 @@ import org.datavec.api.transform.schema.Schema;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public interface ColumnOp {
|
||||
/** Get the output schema for this transformation, given an input schema */
|
||||
Schema transform(Schema inputSchema);
|
||||
|
||||
public interface ColumnOp extends Operation<Schema, Schema> {
|
||||
|
||||
/** Set the input schema.
|
||||
*/
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,16 +13,8 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.datavec.api.transform;
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||
|
||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||
|
||||
/**
|
||||
* A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games).
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor {
|
||||
public abstract void reset();
|
||||
public interface Operation<TIn, TOut> {
|
||||
TOut transform(TIn input);
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.image.transform;
|
||||
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.nd4j.shade.jackson.annotation.JsonInclude;
|
||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||
|
@ -29,15 +30,7 @@ import java.util.Random;
|
|||
*/
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||
public interface ImageTransform {
|
||||
|
||||
/**
|
||||
* Takes an image and returns a transformed image.
|
||||
*
|
||||
* @param image to transform, null == end of stream
|
||||
* @return transformed image
|
||||
*/
|
||||
ImageWritable transform(ImageWritable image);
|
||||
public interface ImageTransform extends Operation<ImageWritable, ImageWritable> {
|
||||
|
||||
/**
|
||||
* Takes an image and returns a transformed image.
|
||||
|
|
6
pom.xml
6
pom.xml
|
@ -297,14 +297,14 @@
|
|||
<numpy.version>1.18.1</numpy.version>
|
||||
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
||||
|
||||
<openblas.version>0.3.8</openblas.version>
|
||||
<openblas.version>0.3.9</openblas.version>
|
||||
<mkl.version>2020.0</mkl.version>
|
||||
<opencv.version>4.2.0</opencv.version>
|
||||
<ffmpeg.version>4.2.2</ffmpeg.version>
|
||||
<leptonica.version>1.79.0</leptonica.version>
|
||||
<hdf5.version>1.10.6</hdf5.version>
|
||||
<hdf5.version>1.12.0</hdf5.version>
|
||||
<ale.version>0.6.1</ale.version>
|
||||
<gym.version>0.15.5</gym.version>
|
||||
<gym.version>0.17.1</gym.version>
|
||||
<tensorflow.version>1.15.2</tensorflow.version>
|
||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||
|
||||
|
|
|
@ -102,6 +102,13 @@
|
|||
<artifactId>gson</artifactId>
|
||||
<version>${gson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.datavec</groupId>
|
||||
<artifactId>datavec-api</artifactId>
|
||||
<version>${datavec.version}</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,18 +13,27 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
package org.deeplearning4j.rl4j.helper;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray
|
||||
* returned by the ObservationPool is packaged into the single INDArray that will be set
|
||||
* in the DataSet of PoolingDataSetPreProcessor.preProcess
|
||||
* INDArray helper methods used by RL4J
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface PoolContentAssembler {
|
||||
INDArray assemble(INDArray[] poolContent);
|
||||
public class INDArrayHelper {
|
||||
/**
|
||||
* MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray.
|
||||
* In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape.
|
||||
*
|
||||
* @param source A INDArray
|
||||
* @return The source INDArray with the correct shape
|
||||
*/
|
||||
public static INDArray forceCorrectShape(INDArray source) {
|
||||
return source.shape()[0] == 1
|
||||
? source
|
||||
: Nd4j.expandDims(source, 0);
|
||||
}
|
||||
}
|
|
@ -90,7 +90,7 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
accuReward += stepReply.getReward() * getConf().getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||
if (!obs.isSkipped()) {
|
||||
|
||||
INDArray[] output = current.outputAll(obs.getData());
|
||||
rewards.add(new MiniTrans(obs.getData(), action, output, accuReward));
|
||||
|
@ -99,7 +99,6 @@ public abstract class AsyncThreadDiscrete<O, NN extends NeuralNet>
|
|||
}
|
||||
|
||||
obs = stepReply.getObservation();
|
||||
|
||||
reward += stepReply.getReward();
|
||||
|
||||
incrementStep();
|
||||
|
|
|
@ -158,7 +158,7 @@ public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O
|
|||
accuReward += stepReply.getReward() * configuration.getRewardFactor();
|
||||
|
||||
//if it's not a skipped frame, you can do a step of training
|
||||
if (!obs.isSkipped() || stepReply.isDone()) {
|
||||
if (!obs.isSkipped()) {
|
||||
|
||||
// Add experience
|
||||
if(pendingTransition != null) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -17,57 +17,43 @@
|
|||
package org.deeplearning4j.rl4j.observation;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* Presently only a dummy container. Will contain observation channels when done.
|
||||
* Represent an observation from the environment
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class Observation {
|
||||
// TODO: Presently only a dummy container. Will contain observation channels when done.
|
||||
|
||||
private final DataSet data;
|
||||
/**
|
||||
* A singleton representing a skipped observation
|
||||
*/
|
||||
public static Observation SkippedObservation = new Observation(null);
|
||||
|
||||
@Getter @Setter
|
||||
private boolean skipped;
|
||||
/**
|
||||
* @return A INDArray containing the data of the observation
|
||||
*/
|
||||
@Getter
|
||||
private final INDArray data;
|
||||
|
||||
public Observation(INDArray[] data) {
|
||||
this(data, false);
|
||||
public boolean isSkipped() {
|
||||
return data == null;
|
||||
}
|
||||
|
||||
public Observation(INDArray[] data, boolean shouldReshape) {
|
||||
INDArray features = Nd4j.concat(0, data);
|
||||
if(shouldReshape) {
|
||||
features = reshape(features);
|
||||
}
|
||||
this.data = new org.nd4j.linalg.dataset.DataSet(features, null);
|
||||
}
|
||||
|
||||
// FIXME: Remove -- only used in unit tests
|
||||
public Observation(INDArray data) {
|
||||
this.data = new org.nd4j.linalg.dataset.DataSet(data, null);
|
||||
}
|
||||
|
||||
private INDArray reshape(INDArray source) {
|
||||
long[] shape = source.shape();
|
||||
long[] nshape = new long[shape.length + 1];
|
||||
nshape[0] = 1;
|
||||
System.arraycopy(shape, 0, nshape, 1, shape.length);
|
||||
|
||||
return source.reshape(nshape);
|
||||
}
|
||||
|
||||
private Observation(DataSet data) {
|
||||
this.data = data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a duplicate instance of the current observation
|
||||
* @return
|
||||
*/
|
||||
public Observation dup() {
|
||||
return new Observation(new org.nd4j.linalg.dataset.DataSet(data.getFeatures().dup(), null));
|
||||
}
|
||||
if(data == null) {
|
||||
return SkippedObservation;
|
||||
}
|
||||
|
||||
public INDArray getData() {
|
||||
return data.getFeatures();
|
||||
return new Observation(data.dup());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,130 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ChannelStackPoolContentAssembler;
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.CircularFifoObservationPool;
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
/**
|
||||
* The PoolingDataSetPreProcessor will accumulate features from incoming DataSets and will assemble its content
|
||||
* into a DataSet containing a single example.
|
||||
*
|
||||
* There are two special cases:
|
||||
* 1) preProcess will return without doing anything if the input DataSet is empty
|
||||
* 2) When the pool has not yet filled, the data from the incoming DataSet is stored in the pool but the DataSet is emptied
|
||||
* on exit.
|
||||
* <br>
|
||||
* The PoolingDataSetPreProcessor requires two sub components: <br>
|
||||
* 1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...)
|
||||
* The default is a Circular FIFO.
|
||||
* 2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...)
|
||||
* The default is stacking along the dimension 0.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor {
|
||||
private final ObservationPool observationPool;
|
||||
private final PoolContentAssembler poolContentAssembler;
|
||||
|
||||
protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder)
|
||||
{
|
||||
observationPool = builder.observationPool;
|
||||
poolContentAssembler = builder.poolContentAssembler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the
|
||||
* Policy/Learning class and other DataSetPreProcessors
|
||||
*
|
||||
* @param dataSet
|
||||
*/
|
||||
@Override
|
||||
public void preProcess(DataSet dataSet) {
|
||||
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||
|
||||
if(dataSet.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported");
|
||||
|
||||
// store a duplicate in the pool
|
||||
observationPool.add(dataSet.getFeatures().slice(0, 0).dup());
|
||||
if(!observationPool.isAtFullCapacity()) {
|
||||
dataSet.setFeatures(null);
|
||||
return;
|
||||
}
|
||||
|
||||
INDArray result = poolContentAssembler.assemble(observationPool.get());
|
||||
|
||||
// return a DataSet containing only 1 example (the result)
|
||||
long[] resultShape = result.shape();
|
||||
long[] newShape = new long[resultShape.length + 1];
|
||||
newShape[0] = 1;
|
||||
System.arraycopy(resultShape, 0, newShape, 1, resultShape.length);
|
||||
|
||||
dataSet.setFeatures(result.reshape(newShape));
|
||||
}
|
||||
|
||||
public static PoolingDataSetPreProcessor.Builder builder() {
|
||||
return new PoolingDataSetPreProcessor.Builder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
observationPool.reset();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private ObservationPool observationPool;
|
||||
private PoolContentAssembler poolContentAssembler;
|
||||
|
||||
/**
|
||||
* Default is CircularFifoObservationPool
|
||||
*/
|
||||
public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) {
|
||||
this.observationPool = observationPool;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Default is ChannelStackPoolContentAssembler
|
||||
*/
|
||||
public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) {
|
||||
this.poolContentAssembler = poolContentAssembler;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PoolingDataSetPreProcessor build() {
|
||||
if(observationPool == null) {
|
||||
observationPool = new CircularFifoObservationPool();
|
||||
}
|
||||
|
||||
if(poolContentAssembler == null) {
|
||||
poolContentAssembler = new ChannelStackPoolContentAssembler();
|
||||
}
|
||||
|
||||
return new PoolingDataSetPreProcessor(this);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||
|
||||
import lombok.Builder;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
|
||||
/**
|
||||
* The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty
|
||||
* the input DataSet when skipping.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor {
|
||||
|
||||
private final int skipFrame;
|
||||
|
||||
private int currentIdx = 0;
|
||||
|
||||
/**
|
||||
* @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations.
|
||||
*/
|
||||
@Builder
|
||||
public SkippingDataSetPreProcessor(int skipFrame) {
|
||||
Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame);
|
||||
this.skipFrame = skipFrame;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preProcess(DataSet dataSet) {
|
||||
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||
|
||||
if(dataSet.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(currentIdx++ % skipFrame != 0) {
|
||||
dataSet.setFeatures(null);
|
||||
dataSet.setLabels(null);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
currentIdx = 0;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Used with {@link TransformProcess TransformProcess} to filter-out an observation.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface FilterOperation {
|
||||
/**
|
||||
* The logic that determines if the observation should be skipped.
|
||||
*
|
||||
* @param channelsData the name of the channel
|
||||
* @param currentObservationStep The step number if the observation in the current episode.
|
||||
* @param isFinalObservation true if this is the last observation of the episode
|
||||
* @return true if the observation should be skipped
|
||||
*/
|
||||
boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation);
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,20 +13,14 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
package org.deeplearning4j.rl4j.observation.transform;
|
||||
|
||||
/**
|
||||
* ObservationPool is used with the PoolingDataSetPreProcessor. Used to supervise how data from the
|
||||
* PoolingDataSetPreProcessor is stored.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
* The {@link TransformProcess TransformProcess} will call reset() (at the start of an episode) of any step that implement this interface.
|
||||
*/
|
||||
public interface ObservationPool {
|
||||
void add(INDArray observation);
|
||||
INDArray[] get();
|
||||
boolean isAtFullCapacity();
|
||||
public interface ResettableOperation {
|
||||
/**
|
||||
* Called by TransformProcess when an episode starts. See {@link TransformProcess#reset() TransformProcess.reset()}
|
||||
*/
|
||||
void reset();
|
||||
}
|
|
@ -0,0 +1,231 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||
import org.nd4j.shade.guava.collect.Maps;
|
||||
import org.datavec.api.transform.Operation;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* A TransformProcess will build an {@link Observation Observation} from the raw data coming from the environment.
|
||||
* Three types of steps are available:
|
||||
* 1) A {@link FilterOperation FilterOperation}: Used to determine if an observation should be skipped.
|
||||
* 2) An {@link Operation Operation}: Applies a transform and/or conversion to an observation channel.
|
||||
* 3) A {@link DataSetPreProcessor DataSetPreProcessor}: Applies a DataSetPreProcessor to the observation channel. The data in the channel must be a DataSet.
|
||||
*
|
||||
* Instances of the three types above can be called in any order. The only requirement is that when build() is called,
|
||||
* all channels must be instances of INDArrays or DataSets
|
||||
*
|
||||
* NOTE: Presently, only single-channels observations are supported.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class TransformProcess {
|
||||
|
||||
private final List<Map.Entry<String, Object>> operations;
|
||||
private final String[] channelNames;
|
||||
private final HashSet<String> operationsChannelNames;
|
||||
|
||||
private TransformProcess(Builder builder, String... channelNames) {
|
||||
operations = builder.operations;
|
||||
this.channelNames = channelNames;
|
||||
this.operationsChannelNames = builder.requiredChannelNames;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will call reset() of all steps implementing {@link ResettableOperation ResettableOperation} in the transform process.
|
||||
*/
|
||||
public void reset() {
|
||||
for(Map.Entry<String, Object> entry : operations) {
|
||||
if(entry.getValue() instanceof ResettableOperation) {
|
||||
((ResettableOperation) entry.getValue()).reset();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Transforms the channel data into an Observation or a skipped observation depending on the specific steps in the transform process.
|
||||
*
|
||||
* @param channelsData A Map that maps the channel name to its data.
|
||||
* @param currentObservationStep The observation's step number within the episode.
|
||||
* @param isFinalObservation True if the observation is the last of the episode.
|
||||
* @return An observation (may be a skipped observation)
|
||||
*/
|
||||
public Observation transform(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
|
||||
// null or empty channelData
|
||||
Preconditions.checkArgument(channelsData != null && channelsData.size() != 0, "Error: channelsData not supplied.");
|
||||
|
||||
// Check that all channels have data
|
||||
for(Map.Entry<String, Object> channel : channelsData.entrySet()) {
|
||||
Preconditions.checkNotNull(channel.getValue(), "Error: data of channel '%s' is null", channel.getKey());
|
||||
}
|
||||
|
||||
// Check that all required channels are present
|
||||
for(String channelName : operationsChannelNames) {
|
||||
Preconditions.checkArgument(channelsData.containsKey(channelName), "The channelsData map does not contain the channel '%s'", channelName);
|
||||
}
|
||||
|
||||
for(Map.Entry<String, Object> entry : operations) {
|
||||
|
||||
// Filter
|
||||
if(entry.getValue() instanceof FilterOperation) {
|
||||
FilterOperation filterOperation = (FilterOperation)entry.getValue();
|
||||
if(filterOperation.isSkipped(channelsData, currentObservationStep, isFinalObservation)) {
|
||||
return Observation.SkippedObservation;
|
||||
}
|
||||
}
|
||||
|
||||
// Transform
|
||||
// null results are considered skipped observations
|
||||
else if(entry.getValue() instanceof Operation) {
|
||||
Operation transformOperation = (Operation)entry.getValue();
|
||||
Object transformed = transformOperation.transform(channelsData.get(entry.getKey()));
|
||||
if(transformed == null) {
|
||||
return Observation.SkippedObservation;
|
||||
}
|
||||
channelsData.replace(entry.getKey(), transformed);
|
||||
}
|
||||
|
||||
// PreProcess
|
||||
else if(entry.getValue() instanceof DataSetPreProcessor) {
|
||||
Object channelData = channelsData.get(entry.getKey());
|
||||
DataSetPreProcessor dataSetPreProcessor = (DataSetPreProcessor)entry.getValue();
|
||||
if(!(channelData instanceof DataSet)) {
|
||||
throw new IllegalArgumentException("The channel data must be a DataSet to call preProcess");
|
||||
}
|
||||
dataSetPreProcessor.preProcess((DataSet)channelData);
|
||||
}
|
||||
|
||||
else {
|
||||
throw new IllegalArgumentException(String.format("Unknown operation: '%s'", entry.getValue().getClass().getName()));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all channels used to build the observation are instances of
|
||||
// INDArray or DataSet
|
||||
// TODO: Add support for an interface with a toINDArray() method
|
||||
for(String channelName : channelNames) {
|
||||
Object channelData = channelsData.get(channelName);
|
||||
|
||||
INDArray finalChannelData;
|
||||
if(channelData instanceof DataSet) {
|
||||
finalChannelData = ((DataSet)channelData).getFeatures();
|
||||
}
|
||||
else if(channelData instanceof INDArray) {
|
||||
finalChannelData = (INDArray) channelData;
|
||||
}
|
||||
else {
|
||||
throw new IllegalStateException("All channels used to build the observation must be instances of DataSet or INDArray");
|
||||
}
|
||||
|
||||
// The dimension 0 of all INDArrays must be 1 (batch count)
|
||||
channelsData.replace(channelName, INDArrayHelper.forceCorrectShape(finalChannelData));
|
||||
}
|
||||
|
||||
// TODO: Add support to multi-channel observations
|
||||
INDArray data = ((INDArray) channelsData.get(channelNames[0]));
|
||||
return new Observation(data);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return An instance of a builder
|
||||
*/
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private final List<Map.Entry<String, Object>> operations = new ArrayList<Map.Entry<String, Object>>();
|
||||
private final HashSet<String> requiredChannelNames = new HashSet<String>();
|
||||
|
||||
/**
|
||||
* Add a filter to the transform process steps. Used to skip observations on certain conditions.
|
||||
* See {@link FilterOperation FilterOperation}
|
||||
* @param filterOperation An instance
|
||||
*/
|
||||
public Builder filter(FilterOperation filterOperation) {
|
||||
Preconditions.checkNotNull(filterOperation, "The filterOperation must not be null");
|
||||
|
||||
operations.add((Map.Entry)Maps.immutableEntry(null, filterOperation));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a transform to the steps. The transform can change the data and / or change the type of the data
|
||||
* (e.g. Byte[] to a ImageWritable)
|
||||
*
|
||||
* @param targetChannel The name of the channel to which the transform is applied.
|
||||
* @param transformOperation An instance of {@link Operation Operation}
|
||||
*/
|
||||
public Builder transform(String targetChannel, Operation transformOperation) {
|
||||
Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
|
||||
Preconditions.checkNotNull(transformOperation, "The transformOperation must not be null");
|
||||
|
||||
requiredChannelNames.add(targetChannel);
|
||||
operations.add((Map.Entry)Maps.immutableEntry(targetChannel, transformOperation));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a DataSetPreProcessor to the steps. The channel must be a DataSet instance at this step.
|
||||
* @param targetChannel The name of the channel to which the pre processor is applied.
|
||||
* @param dataSetPreProcessor
|
||||
*/
|
||||
public Builder preProcess(String targetChannel, DataSetPreProcessor dataSetPreProcessor) {
|
||||
Preconditions.checkNotNull(targetChannel, "The targetChannel must not be null");
|
||||
Preconditions.checkNotNull(dataSetPreProcessor, "The dataSetPreProcessor must not be null");
|
||||
|
||||
requiredChannelNames.add(targetChannel);
|
||||
operations.add((Map.Entry)Maps.immutableEntry(targetChannel, dataSetPreProcessor));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the TransformProcess.
|
||||
* @param channelNames A subset of channel names to be used to build the observation
|
||||
* @return An instance of TransformProcess
|
||||
*/
|
||||
public TransformProcess build(String... channelNames) {
|
||||
if(channelNames.length == 0) {
|
||||
throw new IllegalArgumentException("At least one channel must be supplied.");
|
||||
}
|
||||
|
||||
for(String channelName : channelNames) {
|
||||
Preconditions.checkNotNull(channelName, "Error: got a null channel name");
|
||||
requiredChannelNames.add(channelName);
|
||||
}
|
||||
|
||||
// TODO: Remove when multi-channel observation is supported
|
||||
if(channelNames.length != 1) {
|
||||
throw new NotImplementedException("Multi-channel observations is not presently supported.");
|
||||
}
|
||||
|
||||
return new TransformProcess(this, channelNames);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.filter;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.transform.FilterOperation;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Used with {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}. Will cause the
|
||||
* transform process to skip a fixed number of frames between non skipped ones.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class UniformSkippingFilter implements FilterOperation {
|
||||
|
||||
private final int skipFrame;
|
||||
|
||||
/**
|
||||
* @param skipFrame Will cause the filter to keep (not skip) 1 frame every skipFrames.
|
||||
*/
|
||||
public UniformSkippingFilter(int skipFrame) {
|
||||
Preconditions.checkArgument(skipFrame > 0, "skipFrame should be greater than 0");
|
||||
|
||||
this.skipFrame = skipFrame;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
|
||||
return !isFinalObservation && (currentObservationStep % skipFrame != 0);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
||||
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.bytedeco.opencv.opencv_core.Mat;
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
|
||||
|
||||
public class EncodableToINDArrayTransform implements Operation<Encodable, INDArray> {
|
||||
|
||||
private final int[] shape;
|
||||
|
||||
public EncodableToINDArrayTransform(int[] shape) {
|
||||
this.shape = shape;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transform(Encodable encodable) {
|
||||
return Nd4j.create(encodable.toArray()).reshape(shape);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
||||
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.bytedeco.opencv.opencv_core.Mat;
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
|
||||
|
||||
public class EncodableToImageWritableTransform implements Operation<Encodable, ImageWritable> {
|
||||
|
||||
private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
|
||||
private final int height;
|
||||
private final int width;
|
||||
private final int colorChannels;
|
||||
|
||||
public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
|
||||
this.height = height;
|
||||
this.width = width;
|
||||
this.colorChannels = colorChannels;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageWritable transform(Encodable encodable) {
|
||||
INDArray indArray = Nd4j.create((encodable).toArray()).reshape(height, width, colorChannels);
|
||||
Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
|
||||
return new ImageWritable(converter.convert(mat));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.legacy;
|
||||
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
public class ImageWritableToINDArrayTransform implements Operation<ImageWritable, INDArray> {
|
||||
|
||||
private final int height;
|
||||
private final int width;
|
||||
private final NativeImageLoader loader;
|
||||
|
||||
public ImageWritableToINDArrayTransform(int height, int width) {
|
||||
this.height = height;
|
||||
this.width = width;
|
||||
this.loader = new NativeImageLoader(height, width);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transform(ImageWritable imageWritable) {
|
||||
INDArray out = null;
|
||||
try {
|
||||
out = loader.asMatrix(imageWritable);
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
out = out.reshape(1, height, width);
|
||||
INDArray compressed = out.castTo(DataType.UINT8);
|
||||
return compressed;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation;
|
||||
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
|
||||
import org.deeplearning4j.rl4j.observation.transform.ResettableOperation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryStackAssembler;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* The HistoryMergeTransform will accumulate features from incoming INDArrays and will assemble its content
|
||||
* into a new INDArray containing a single example.
|
||||
*
|
||||
* This is used in scenarios where motion in an important element.
|
||||
*
|
||||
* There is a special case:
|
||||
* * When the store is not full (not ready), the data from the incoming INDArray is stored but null is returned (will be interpreted as a skipped observation)
|
||||
* <br>
|
||||
* The HistoryMergeTransform requires two sub components: <br>
|
||||
* 1) The {@link HistoryMergeElementStore HistoryMergeElementStore} that supervises what and how input INDArrays are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...)
|
||||
* The default is a Circular FIFO.
|
||||
* 2) The {@link HistoryMergeAssembler HistoryMergeAssembler} that will assemble the store content into a resulting single INDArray. (ex.: stacked along a dimension, squashed into a single observation, etc...)
|
||||
* The default is stacking along the dimension 0.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class HistoryMergeTransform implements Operation<INDArray, INDArray>, ResettableOperation {
|
||||
|
||||
private final HistoryMergeElementStore historyMergeElementStore;
|
||||
private final HistoryMergeAssembler historyMergeAssembler;
|
||||
private final boolean shouldStoreCopy;
|
||||
private final boolean isFirstDimenstionBatch;
|
||||
|
||||
private HistoryMergeTransform(Builder builder) {
|
||||
this.historyMergeElementStore = builder.historyMergeElementStore;
|
||||
this.historyMergeAssembler = builder.historyMergeAssembler;
|
||||
this.shouldStoreCopy = builder.shouldStoreCopy;
|
||||
this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transform(INDArray input) {
|
||||
INDArray element;
|
||||
if(isFirstDimenstionBatch) {
|
||||
element = input.slice(0, 0);
|
||||
}
|
||||
else {
|
||||
element = input;
|
||||
}
|
||||
|
||||
if(shouldStoreCopy) {
|
||||
element = element.dup();
|
||||
}
|
||||
|
||||
historyMergeElementStore.add(element);
|
||||
if(!historyMergeElementStore.isReady()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
INDArray result = historyMergeAssembler.assemble(historyMergeElementStore.get());
|
||||
|
||||
return INDArrayHelper.forceCorrectShape(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
historyMergeElementStore.reset();
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private HistoryMergeElementStore historyMergeElementStore;
|
||||
private HistoryMergeAssembler historyMergeAssembler;
|
||||
private boolean shouldStoreCopy = false;
|
||||
private boolean isFirstDimenstionBatch = false;
|
||||
|
||||
/**
|
||||
* Default is {@link CircularFifoStore CircularFifoStore}
|
||||
*/
|
||||
public Builder elementStore(HistoryMergeElementStore store) {
|
||||
this.historyMergeElementStore = store;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Default is {@link HistoryStackAssembler HistoryStackAssembler}
|
||||
*/
|
||||
public Builder assembler(HistoryMergeAssembler assembler) {
|
||||
this.historyMergeAssembler = assembler;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* If true, tell the HistoryMergeTransform to store copies of incoming INDArrays.
|
||||
* (To prevent later in-place changes to a stored INDArray from changing what has been stored)
|
||||
*
|
||||
* Default is false
|
||||
*/
|
||||
public Builder shouldStoreCopy(boolean shouldStoreCopy) {
|
||||
this.shouldStoreCopy = shouldStoreCopy;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* If true, tell the HistoryMergeTransform that the first dimension of the input INDArray is the batch count.
|
||||
* When this is the case, the HistoryMergeTransform will slice the input like this [batch, height, width] -> [height, width]
|
||||
*
|
||||
* Default is false
|
||||
*/
|
||||
public Builder isFirstDimenstionBatch(boolean isFirstDimenstionBatch) {
|
||||
this.isFirstDimenstionBatch = isFirstDimenstionBatch;
|
||||
return this;
|
||||
}
|
||||
|
||||
public HistoryMergeTransform build() {
|
||||
if(historyMergeElementStore == null) {
|
||||
historyMergeElementStore = new CircularFifoStore();
|
||||
}
|
||||
|
||||
if(historyMergeAssembler == null) {
|
||||
historyMergeAssembler = new HistoryStackAssembler();
|
||||
}
|
||||
|
||||
return new HistoryMergeTransform(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation;
|
||||
|
||||
import org.datavec.api.transform.Operation;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public class SimpleNormalizationTransform implements Operation<INDArray, INDArray> {
|
||||
|
||||
private final double offset;
|
||||
private final double divisor;
|
||||
|
||||
public SimpleNormalizationTransform(double min, double max) {
|
||||
Preconditions.checkArgument(min < max, "Min must be smaller than max.");
|
||||
|
||||
this.offset = min;
|
||||
this.divisor = (max - min);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray transform(INDArray input) {
|
||||
if(offset != 0.0) {
|
||||
input.subi(offset);
|
||||
}
|
||||
|
||||
input.divi(divisor);
|
||||
|
||||
return input;
|
||||
}
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -13,51 +13,47 @@
|
|||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue
|
||||
* CircularFifoStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This store is a first-in first-out queue
|
||||
* with a fixed size that replaces its oldest element if full.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class CircularFifoObservationPool implements ObservationPool {
|
||||
private static final int DEFAULT_POOL_SIZE = 4;
|
||||
public class CircularFifoStore implements HistoryMergeElementStore {
|
||||
private static final int DEFAULT_STORE_SIZE = 4;
|
||||
|
||||
private final CircularFifoQueue<INDArray> queue;
|
||||
|
||||
private CircularFifoObservationPool(Builder builder) {
|
||||
queue = new CircularFifoQueue<>(builder.poolSize);
|
||||
public CircularFifoStore() {
|
||||
this(DEFAULT_STORE_SIZE);
|
||||
}
|
||||
|
||||
public CircularFifoObservationPool()
|
||||
{
|
||||
this(DEFAULT_POOL_SIZE);
|
||||
}
|
||||
|
||||
public CircularFifoObservationPool(int poolSize)
|
||||
{
|
||||
Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize);
|
||||
queue = new CircularFifoQueue<>(poolSize);
|
||||
public CircularFifoStore(int size) {
|
||||
Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
|
||||
queue = new CircularFifoQueue<>(size);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one.
|
||||
* Add an element to the store, if this addition would make the store to overflow, the new element replaces the oldest.
|
||||
* @param elem
|
||||
*/
|
||||
@Override
|
||||
public void add(INDArray elem) {
|
||||
queue.add(elem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The content of the pool, returned in order from oldest to newest.
|
||||
* @return The content of the store, returned in order from oldest to newest.
|
||||
*/
|
||||
@Override
|
||||
public INDArray[] get() {
|
||||
int size = queue.size();
|
||||
INDArray[] array = new INDArray[size];
|
||||
|
@ -67,29 +63,20 @@ public class CircularFifoObservationPool implements ObservationPool {
|
|||
return array;
|
||||
}
|
||||
|
||||
public boolean isAtFullCapacity() {
|
||||
/**
|
||||
* The CircularFifoStore needs to be completely filled before being ready.
|
||||
* @return false when the number of elements in the store is less than the store capacity (default is 4)
|
||||
*/
|
||||
@Override
|
||||
public boolean isReady() {
|
||||
return queue.isAtFullCapacity();
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the store.
|
||||
*/
|
||||
@Override
|
||||
public void reset() {
|
||||
queue.clear();
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private int poolSize = DEFAULT_POOL_SIZE;
|
||||
|
||||
public Builder poolSize(int poolSize) {
|
||||
this.poolSize = poolSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
public CircularFifoObservationPool build() {
|
||||
return new CircularFifoObservationPool(this);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* A HistoryMergeAssembler is used with the {@link HistoryMergeTransform HistoryMergeTransform}. This interface defines how the array of INDArray
|
||||
* given by the {@link HistoryMergeElementStore HistoryMergeElementStore} is packaged into the single INDArray that will be
|
||||
* returned by the HistoryMergeTransform
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface HistoryMergeAssembler {
|
||||
/**
|
||||
* Assemble an array of INDArray into a single INArray
|
||||
* @param elements The input INDArray[]
|
||||
* @return the assembled INDArray
|
||||
*/
|
||||
INDArray assemble(INDArray[] elements);
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* HistoryMergeElementStore is used with the {@link HistoryMergeTransform HistoryMergeTransform}. Used to supervise how data from the
|
||||
* HistoryMergeTransform is stored.
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public interface HistoryMergeElementStore {
|
||||
/**
|
||||
* Add an element into the store
|
||||
* @param observation
|
||||
*/
|
||||
void add(INDArray observation);
|
||||
|
||||
/**
|
||||
* Get the content of the store
|
||||
* @return the content of the store
|
||||
*/
|
||||
INDArray[] get();
|
||||
|
||||
/**
|
||||
* Used to tell the HistoryMergeTransform that the store is ready. The HistoryMergeTransform will tell the {@link org.deeplearning4j.rl4j.observation.transform.TransformProcess TransformProcess}
|
||||
* to skip the observation is the store is not ready.
|
||||
* @return true if the store is ready
|
||||
*/
|
||||
boolean isReady();
|
||||
|
||||
/**
|
||||
* Resets the store to an initial state.
|
||||
*/
|
||||
void reset();
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -14,39 +14,38 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
* ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will
|
||||
* stack along the dimension 0. For example if the pool elements are of shape [ Height, Width ]
|
||||
* HistoryStackAssembler is used with the HistoryMergeTransform. This assembler will
|
||||
* stack along the dimension 0. For example if the store elements are of shape [ Height, Width ]
|
||||
* the output will be of shape [ Stacked, Height, Width ]
|
||||
*
|
||||
* @author Alexandre Boulanger
|
||||
*/
|
||||
public class ChannelStackPoolContentAssembler implements PoolContentAssembler {
|
||||
public class HistoryStackAssembler implements HistoryMergeAssembler {
|
||||
|
||||
/**
|
||||
* Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0.
|
||||
* Will return a new INDArray with one more dimension and with elements stacked along dimension 0.
|
||||
*
|
||||
* @param poolContent Array of INDArray
|
||||
* @param elements Array of INDArray
|
||||
* @return A new INDArray with 1 more dimension than the input elements
|
||||
*/
|
||||
@Override
|
||||
public INDArray assemble(INDArray[] poolContent)
|
||||
{
|
||||
public INDArray assemble(INDArray[] elements) {
|
||||
// build the new shape
|
||||
long[] elementShape = poolContent[0].shape();
|
||||
long[] elementShape = elements[0].shape();
|
||||
long[] newShape = new long[elementShape.length + 1];
|
||||
newShape[0] = poolContent.length;
|
||||
newShape[0] = elements.length;
|
||||
System.arraycopy(elementShape, 0, newShape, 1, elementShape.length);
|
||||
|
||||
// put pool elements in result
|
||||
// stack the elements in result on the dimension 0
|
||||
INDArray result = Nd4j.create(newShape);
|
||||
for(int i = 0; i < poolContent.length; ++i) {
|
||||
result.putRow(i, poolContent[i]);
|
||||
for(int i = 0; i < elements.length; ++i) {
|
||||
result.putRow(i, elements[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
|
@ -89,7 +89,7 @@ public abstract class Policy<O, A> implements IPolicy<O, A> {
|
|||
getNeuralNet().reset();
|
||||
}
|
||||
|
||||
private <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
|
||||
protected <AS extends ActionSpace<A>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<O, A, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
|
||||
epochStepCounter.setCurrentEpochStep(0);
|
||||
|
||||
double reward = 0;
|
||||
|
|
|
@ -3,37 +3,95 @@ package org.deeplearning4j.rl4j.util;
|
|||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.datavec.image.transform.ColorConversionTransform;
|
||||
import org.datavec.image.transform.CropImageTransform;
|
||||
import org.datavec.image.transform.MultiImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.learning.EpochStepCounter;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.bytedeco.opencv.global.opencv_imgproc.COLOR_BGR2GRAY;
|
||||
|
||||
public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
|
||||
|
||||
@Getter
|
||||
private final MDP<O, A, AS> wrappedMDP;
|
||||
@Getter
|
||||
private final WrapperObservationSpace observationSpace;
|
||||
private final int[] shape;
|
||||
|
||||
@Getter(AccessLevel.PRIVATE) @Setter(AccessLevel.PUBLIC)
|
||||
@Setter
|
||||
private TransformProcess transformProcess;
|
||||
|
||||
@Getter(AccessLevel.PRIVATE)
|
||||
private IHistoryProcessor historyProcessor;
|
||||
|
||||
private final EpochStepCounter epochStepCounter;
|
||||
|
||||
private int skipFrame = 1;
|
||||
private int requiredFrame = 0;
|
||||
|
||||
public LegacyMDPWrapper(MDP<O, A, AS> wrappedMDP, IHistoryProcessor historyProcessor, EpochStepCounter epochStepCounter) {
|
||||
this.wrappedMDP = wrappedMDP;
|
||||
this.observationSpace = new WrapperObservationSpace(wrappedMDP.getObservationSpace().getShape());
|
||||
this.shape = wrappedMDP.getObservationSpace().getShape();
|
||||
this.observationSpace = new WrapperObservationSpace(shape);
|
||||
this.historyProcessor = historyProcessor;
|
||||
this.epochStepCounter = epochStepCounter;
|
||||
|
||||
setHistoryProcessor(historyProcessor);
|
||||
}
|
||||
|
||||
public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
|
||||
this.historyProcessor = historyProcessor;
|
||||
createTransformProcess();
|
||||
}
|
||||
|
||||
private void createTransformProcess() {
|
||||
IHistoryProcessor historyProcessor = getHistoryProcessor();
|
||||
|
||||
if(historyProcessor != null && shape.length == 3) {
|
||||
int skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||
|
||||
int finalHeight = historyProcessor.getConf().getCroppingHeight();
|
||||
int finalWidth = historyProcessor.getConf().getCroppingWidth();
|
||||
|
||||
transformProcess = TransformProcess.builder()
|
||||
.filter(new UniformSkippingFilter(skipFrame))
|
||||
.transform("data", new EncodableToImageWritableTransform(shape[0], shape[1], shape[2]))
|
||||
.transform("data", new MultiImageTransform(
|
||||
new ResizeImageTransform(historyProcessor.getConf().getRescaledWidth(), historyProcessor.getConf().getRescaledHeight()),
|
||||
new ColorConversionTransform(COLOR_BGR2GRAY),
|
||||
new CropImageTransform(historyProcessor.getConf().getOffsetY(), historyProcessor.getConf().getOffsetX(), finalHeight, finalWidth)
|
||||
))
|
||||
.transform("data", new ImageWritableToINDArrayTransform(finalHeight, finalWidth))
|
||||
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
||||
.transform("data", HistoryMergeTransform.builder()
|
||||
.isFirstDimenstionBatch(true)
|
||||
.build())
|
||||
.build("data");
|
||||
}
|
||||
else {
|
||||
transformProcess = TransformProcess.builder()
|
||||
.transform("data", new EncodableToINDArrayTransform(shape))
|
||||
.build("data");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -43,25 +101,17 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
|
||||
@Override
|
||||
public Observation reset() {
|
||||
INDArray rawObservation = getInput(wrappedMDP.reset());
|
||||
transformProcess.reset();
|
||||
|
||||
IHistoryProcessor historyProcessor = getHistoryProcessor();
|
||||
if(historyProcessor != null) {
|
||||
historyProcessor.record(rawObservation);
|
||||
}
|
||||
|
||||
Observation observation = new Observation(new INDArray[] { rawObservation }, false);
|
||||
O rawResetResponse = wrappedMDP.reset();
|
||||
record(rawResetResponse);
|
||||
|
||||
if(historyProcessor != null) {
|
||||
skipFrame = historyProcessor.getConf().getSkipFrame();
|
||||
requiredFrame = skipFrame * (historyProcessor.getConf().getHistoryLength() - 1);
|
||||
|
||||
historyProcessor.add(rawObservation);
|
||||
}
|
||||
|
||||
observation.setSkipped(skipFrame != 0);
|
||||
|
||||
return observation;
|
||||
Map<String, Object> channelsData = buildChannelsData(rawResetResponse);
|
||||
return transformProcess.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -71,32 +121,32 @@ public class LegacyMDPWrapper<O, A, AS extends ActionSpace<A>> implements MDP<Ob
|
|||
StepReply<O> rawStepReply = wrappedMDP.step(a);
|
||||
INDArray rawObservation = getInput(rawStepReply.getObservation());
|
||||
|
||||
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
|
||||
|
||||
if(historyProcessor != null) {
|
||||
historyProcessor.record(rawObservation);
|
||||
|
||||
if (stepOfObservation % skipFrame == 0) {
|
||||
historyProcessor.add(rawObservation);
|
||||
}
|
||||
}
|
||||
|
||||
Observation observation;
|
||||
if(historyProcessor != null && stepOfObservation >= requiredFrame) {
|
||||
observation = new Observation(historyProcessor.getHistory(), true);
|
||||
observation.getData().muli(1.0 / historyProcessor.getScale());
|
||||
}
|
||||
else {
|
||||
observation = new Observation(new INDArray[] { rawObservation }, false);
|
||||
}
|
||||
|
||||
if(stepOfObservation % skipFrame != 0 || stepOfObservation < requiredFrame) {
|
||||
observation.setSkipped(true);
|
||||
}
|
||||
int stepOfObservation = epochStepCounter.getCurrentEpochStep() + 1;
|
||||
|
||||
Map<String, Object> channelsData = buildChannelsData(rawStepReply.getObservation());
|
||||
Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
|
||||
return new StepReply<Observation>(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
|
||||
}
|
||||
|
||||
private void record(O obs) {
|
||||
INDArray rawObservation = getInput(obs);
|
||||
|
||||
IHistoryProcessor historyProcessor = getHistoryProcessor();
|
||||
if(historyProcessor != null) {
|
||||
historyProcessor.record(rawObservation);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Object> buildChannelsData(final O obs) {
|
||||
return new HashMap<String, Object>() {{
|
||||
put("data", obs);
|
||||
}};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
wrappedMDP.close();
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package org.deeplearning4j.rl4j.helper;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class INDArrayHelperTest {
|
||||
@Test
|
||||
public void when_inputHasIncorrectShape_expect_outputWithCorrectShape() {
|
||||
// Arrange
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0});
|
||||
|
||||
// Act
|
||||
INDArray output = INDArrayHelper.forceCorrectShape(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, output.shape().length);
|
||||
assertEquals(1, output.shape()[0]);
|
||||
assertEquals(3, output.shape()[1]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_inputHasCorrectShape_expect_outputWithSameShape() {
|
||||
// Arrange
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0}).reshape(1, 3);
|
||||
|
||||
// Act
|
||||
INDArray output = INDArrayHelper.forceCorrectShape(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, output.shape().length);
|
||||
assertEquals(1, output.shape()[0]);
|
||||
assertEquals(3, output.shape()[1]);
|
||||
}
|
||||
|
||||
}
|
|
@ -34,6 +34,7 @@ public class AsyncThreadDiscreteTest {
|
|||
MockPolicy policyMock = new MockPolicy();
|
||||
MockAsyncConfiguration config = new MockAsyncConfiguration(5, 100, 0, 0, 2, 5,0, 0, 0, 0);
|
||||
TestAsyncThreadDiscrete sut = new TestAsyncThreadDiscrete(asyncGlobalMock, mdpMock, listeners, 0, 0, policyMock, config, hpMock);
|
||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||
|
||||
// Act
|
||||
sut.run();
|
||||
|
@ -60,12 +61,6 @@ public class AsyncThreadDiscreteTest {
|
|||
assertEquals(2, asyncGlobalMock.enqueueCallCount);
|
||||
|
||||
// HistoryProcessor
|
||||
double[] expectedAddValues = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0 };
|
||||
assertEquals(expectedAddValues.length, hpMock.addCalls.size());
|
||||
for(int i = 0; i < expectedAddValues.length; ++i) {
|
||||
assertEquals(expectedAddValues[i], hpMock.addCalls.get(i).getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
double[] expectedRecordValues = new double[] { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, };
|
||||
assertEquals(expectedRecordValues.length, hpMock.recordCalls.size());
|
||||
for(int i = 0; i < expectedRecordValues.length; ++i) {
|
||||
|
|
|
@ -138,6 +138,7 @@ public class AsyncThreadTest {
|
|||
asyncGlobal.setMaxLoops(numEpochs);
|
||||
listeners.add(listener);
|
||||
sut.setHistoryProcessor(historyProcessor);
|
||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -209,7 +210,4 @@ public class AsyncThreadTest {
|
|||
int nstep;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition = buildTransition(buildObservation(),
|
||||
123, 234, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition);
|
||||
List<Transition<Integer>> results = sut.getBatch(1);
|
||||
|
@ -36,11 +36,11 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(2, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition1 = buildTransition(buildObservation(),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition2 = buildTransition(buildObservation(),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition3 = buildTransition(buildObservation(),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
|
@ -78,11 +78,11 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition1 = buildTransition(buildObservation(),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition2 = buildTransition(buildObservation(),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition3 = buildTransition(buildObservation(),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
|
@ -100,11 +100,11 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition1 = buildTransition(buildObservation(),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition2 = buildTransition(buildObservation(),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition3 = buildTransition(buildObservation(),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
|
@ -131,15 +131,15 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition1 = buildTransition(buildObservation(),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition2 = buildTransition(buildObservation(),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition3 = buildTransition(buildObservation(),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition4 = buildTransition(buildObservation(),
|
||||
7, 8, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition5 = buildTransition(buildObservation(),
|
||||
9, 10, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
|
@ -168,15 +168,15 @@ public class ExpReplayTest {
|
|||
ExpReplay<Integer> sut = new ExpReplay<Integer>(5, 1, randomMock);
|
||||
|
||||
// Act
|
||||
Transition<Integer> transition1 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition1 = buildTransition(buildObservation(),
|
||||
1, 2, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition2 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition2 = buildTransition(buildObservation(),
|
||||
3, 4, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition3 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition3 = buildTransition(buildObservation(),
|
||||
5, 6, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition4 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition4 = buildTransition(buildObservation(),
|
||||
7, 8, new Observation(Nd4j.create(1)));
|
||||
Transition<Integer> transition5 = buildTransition(new Observation(new INDArray[] { Nd4j.create(1) }),
|
||||
Transition<Integer> transition5 = buildTransition(buildObservation(),
|
||||
9, 10, new Observation(Nd4j.create(1)));
|
||||
sut.store(transition1);
|
||||
sut.store(transition2);
|
||||
|
@ -204,4 +204,8 @@ public class ExpReplayTest {
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
private Observation buildObservation() {
|
||||
return new Observation(Nd4j.create(1, 1));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -193,11 +193,11 @@ public class TransitionTest {
|
|||
Nd4j.create(obs[1]).reshape(1, 3),
|
||||
Nd4j.create(obs[2]).reshape(1, 3),
|
||||
};
|
||||
return new Observation(history);
|
||||
return new Observation(Nd4j.concat(0, history));
|
||||
}
|
||||
|
||||
private Observation buildObservation(double[] obs) {
|
||||
return new Observation(new INDArray[] { Nd4j.create(obs).reshape(1, 3) });
|
||||
return new Observation(Nd4j.create(obs).reshape(1, 3));
|
||||
}
|
||||
|
||||
private Observation buildNextObservation(double[][] obs, double[] nextObs) {
|
||||
|
@ -206,7 +206,7 @@ public class TransitionTest {
|
|||
Nd4j.create(obs[0]).reshape(1, 3),
|
||||
Nd4j.create(obs[1]).reshape(1, 3),
|
||||
};
|
||||
return new Observation(nextHistory);
|
||||
return new Observation(Nd4j.concat(0, nextHistory));
|
||||
}
|
||||
|
||||
private Transition buildTransition(Observation observation, int action, double reward, Observation nextObservation) {
|
||||
|
|
|
@ -50,6 +50,7 @@ public class QLearningDiscreteTest {
|
|||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
sut.setHistoryProcessor(hp);
|
||||
sut.getLegacyMDPWrapper().setTransformProcess(MockMDP.buildTransformProcess(observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength()));
|
||||
List<QLearning.QLStepReturn<MockEncodable>> results = new ArrayList<>();
|
||||
|
||||
// Act
|
||||
|
@ -62,11 +63,7 @@ public class QLearningDiscreteTest {
|
|||
for(int i = 0; i < expectedRecords.length; ++i) {
|
||||
assertEquals(expectedRecords[i], hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
double[] expectedAdds = new double[] { 0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0 };
|
||||
assertEquals(expectedAdds.length, hp.addCalls.size());
|
||||
for(int i = 0; i < expectedAdds.length; ++i) {
|
||||
assertEquals(expectedAdds[i], hp.addCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
|
||||
assertEquals(0, hp.startMonitorCallCount);
|
||||
assertEquals(0, hp.stopMonitorCallCount);
|
||||
|
||||
|
|
|
@ -106,7 +106,7 @@ public class DoubleDQNTest {
|
|||
}
|
||||
|
||||
private Observation buildObservation(double[] data) {
|
||||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||
return new Observation(Nd4j.create(data).reshape(1, 2));
|
||||
}
|
||||
|
||||
private Transition<Integer> builtTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||
|
|
|
@ -105,7 +105,7 @@ public class StandardDQNTest {
|
|||
}
|
||||
|
||||
private Observation buildObservation(double[] data) {
|
||||
return new Observation(new INDArray[]{Nd4j.create(data).reshape(1, 2)});
|
||||
return new Observation(Nd4j.create(data).reshape(1, 2));
|
||||
}
|
||||
|
||||
private Transition<Integer> buildTransition(Observation observation, Integer action, double reward, boolean isTerminal, Observation nextObservation) {
|
||||
|
|
|
@ -1,164 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
|
||||
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static junit.framework.TestCase.assertTrue;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class PoolingDataSetPreProcessorTest {
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||
// Assemble
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||
|
||||
// Act
|
||||
sut.preProcess(null);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() {
|
||||
// Assemble
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||
|
||||
// Act
|
||||
sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_dataSetIsEmpty_expect_EmptyDataSet() {
|
||||
// Assemble
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||
DataSet ds = new DataSet(null, null);
|
||||
|
||||
// Act
|
||||
sut.preProcess(ds);
|
||||
|
||||
// Assert
|
||||
Assert.assertTrue(ds.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() {
|
||||
// Arrange
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||
DataSet[] observations = new DataSet[5];
|
||||
INDArray[] inputs = new INDArray[5];
|
||||
|
||||
|
||||
// Act
|
||||
for(int i = 0; i < 5; ++i) {
|
||||
inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 });
|
||||
DataSet input = new DataSet(inputs[i], null);
|
||||
sut.preProcess(input);
|
||||
observations[i] = input;
|
||||
}
|
||||
|
||||
// Assert
|
||||
assertTrue(observations[0].isEmpty());
|
||||
assertTrue(observations[1].isEmpty());
|
||||
assertTrue(observations[2].isEmpty());
|
||||
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
|
||||
assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
|
||||
assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
|
||||
assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
|
||||
}
|
||||
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
|
||||
assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
|
||||
assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
|
||||
assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() {
|
||||
// Arrange
|
||||
INDArray input = Nd4j.rand(1, 1);
|
||||
TestObservationPool pool = new TestObservationPool();
|
||||
TestPoolContentAssembler assembler = new TestPoolContentAssembler();
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
|
||||
.observablePool(pool)
|
||||
.poolContentAssembler(assembler)
|
||||
.build();
|
||||
|
||||
// Act
|
||||
sut.preProcess(new DataSet(input, null));
|
||||
|
||||
// Assert
|
||||
assertTrue(pool.isAtFullCapacityCalled);
|
||||
assertTrue(pool.isGetCalled);
|
||||
assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0);
|
||||
assertTrue(assembler.assembleIsCalled);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_pastInputChanges_expect_outputNotChanged() {
|
||||
// Arrange
|
||||
INDArray input = Nd4j.zeros(1, 1);
|
||||
TestObservationPool pool = new TestObservationPool();
|
||||
TestPoolContentAssembler assembler = new TestPoolContentAssembler();
|
||||
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
|
||||
.observablePool(pool)
|
||||
.poolContentAssembler(assembler)
|
||||
.build();
|
||||
|
||||
// Act
|
||||
sut.preProcess(new DataSet(input, null));
|
||||
input.putScalar(0, 0, 1.0);
|
||||
|
||||
// Assert
|
||||
assertEquals(0.0, pool.observation.getDouble(0), 0.0);
|
||||
}
|
||||
|
||||
private static class TestObservationPool implements ObservationPool {
|
||||
|
||||
public INDArray observation;
|
||||
public boolean isGetCalled;
|
||||
public boolean isAtFullCapacityCalled;
|
||||
private boolean isResetCalled;
|
||||
|
||||
@Override
|
||||
public void add(INDArray observation) {
|
||||
this.observation = observation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] get() {
|
||||
isGetCalled = true;
|
||||
return new INDArray[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isAtFullCapacity() {
|
||||
isAtFullCapacityCalled = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
isResetCalled = true;
|
||||
}
|
||||
}
|
||||
|
||||
private static class TestPoolContentAssembler implements PoolContentAssembler {
|
||||
|
||||
public boolean assembleIsCalled;
|
||||
|
||||
@Override
|
||||
public INDArray assemble(INDArray[] poolContent) {
|
||||
assembleIsCalled = true;
|
||||
return Nd4j.create(1, 1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class SkippingDataSetPreProcessorTest {
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() {
|
||||
SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() {
|
||||
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||
.skipFrame(0)
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_skipFrameIs3_expect_Skip2OutOf3() {
|
||||
// Arrange
|
||||
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||
.skipFrame(3)
|
||||
.build();
|
||||
DataSet[] results = new DataSet[4];
|
||||
|
||||
// Act
|
||||
for(int i = 0; i < 4; ++i) {
|
||||
results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||
sut.preProcess(results[i]);
|
||||
}
|
||||
|
||||
// Assert
|
||||
assertFalse(results[0].isEmpty());
|
||||
assertTrue(results[1].isEmpty());
|
||||
assertTrue(results[2].isEmpty());
|
||||
assertFalse(results[3].isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_resetIsCalled_expect_skippingIsReset() {
|
||||
// Arrange
|
||||
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||
.skipFrame(3)
|
||||
.build();
|
||||
DataSet[] results = new DataSet[4];
|
||||
|
||||
// Act
|
||||
results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||
results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||
results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||
results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||
|
||||
sut.preProcess(results[0]);
|
||||
sut.preProcess(results[1]);
|
||||
sut.reset();
|
||||
sut.preProcess(results[2]);
|
||||
sut.preProcess(results[3]);
|
||||
|
||||
// Assert
|
||||
assertFalse(results[0].isEmpty());
|
||||
assertTrue(results[1].isEmpty());
|
||||
assertFalse(results[2].isEmpty());
|
||||
assertTrue(results[3].isEmpty());
|
||||
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class ChannelStackPoolContentAssemblerTest {
|
||||
|
||||
@Test
|
||||
public void when_assemble_expect_poolContentStackedOnChannel() {
|
||||
// Assemble
|
||||
ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler();
|
||||
INDArray[] poolContent = new INDArray[] {
|
||||
Nd4j.rand(2, 2),
|
||||
Nd4j.rand(2, 2),
|
||||
};
|
||||
|
||||
// Act
|
||||
INDArray result = sut.assemble(poolContent);
|
||||
|
||||
// Assert
|
||||
assertEquals(3, result.shape().length);
|
||||
assertEquals(2, result.shape()[0]);
|
||||
assertEquals(2, result.shape()[1]);
|
||||
assertEquals(2, result.shape()[2]);
|
||||
|
||||
assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001);
|
||||
assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001);
|
||||
assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001);
|
||||
assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001);
|
||||
|
||||
assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001);
|
||||
assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001);
|
||||
assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001);
|
||||
assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class CircularFifoObservationPoolTest {
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() {
|
||||
CircularFifoObservationPool sut = new CircularFifoObservationPool(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_poolIsEmpty_expect_NotReady() {
|
||||
// Assemble
|
||||
CircularFifoObservationPool sut = new CircularFifoObservationPool();
|
||||
|
||||
// Act
|
||||
boolean isReady = sut.isAtFullCapacity();
|
||||
|
||||
// Assert
|
||||
assertFalse(isReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_notEnoughElementsInPool_expect_notReady() {
|
||||
// Assemble
|
||||
CircularFifoObservationPool sut = new CircularFifoObservationPool();
|
||||
sut.add(Nd4j.create(new double[] { 123.0 }));
|
||||
|
||||
// Act
|
||||
boolean isReady = sut.isAtFullCapacity();
|
||||
|
||||
// Assert
|
||||
assertFalse(isReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_enoughElementsInPool_expect_ready() {
|
||||
// Assemble
|
||||
CircularFifoObservationPool sut = CircularFifoObservationPool.builder()
|
||||
.poolSize(2)
|
||||
.build();
|
||||
sut.add(Nd4j.createFromArray(123.0));
|
||||
sut.add(Nd4j.createFromArray(123.0));
|
||||
|
||||
// Act
|
||||
boolean isReady = sut.isAtFullCapacity();
|
||||
|
||||
// Assert
|
||||
assertTrue(isReady);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_addMoreThanSize_expect_getReturnOnlyLastElements() {
|
||||
// Assemble
|
||||
CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
|
||||
sut.add(Nd4j.createFromArray(0.0));
|
||||
sut.add(Nd4j.createFromArray(1.0));
|
||||
sut.add(Nd4j.createFromArray(2.0));
|
||||
sut.add(Nd4j.createFromArray(3.0));
|
||||
sut.add(Nd4j.createFromArray(4.0));
|
||||
sut.add(Nd4j.createFromArray(5.0));
|
||||
sut.add(Nd4j.createFromArray(6.0));
|
||||
|
||||
// Act
|
||||
INDArray[] result = sut.get();
|
||||
|
||||
// Assert
|
||||
assertEquals(3.0, result[0].getDouble(0), 0.0);
|
||||
assertEquals(4.0, result[1].getDouble(0), 0.0);
|
||||
assertEquals(5.0, result[2].getDouble(0), 0.0);
|
||||
assertEquals(6.0, result[3].getDouble(0), 0.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_resetIsCalled_expect_poolContentFlushed() {
|
||||
// Assemble
|
||||
CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
|
||||
sut.add(Nd4j.createFromArray(0.0));
|
||||
sut.add(Nd4j.createFromArray(1.0));
|
||||
sut.add(Nd4j.createFromArray(2.0));
|
||||
sut.add(Nd4j.createFromArray(3.0));
|
||||
sut.add(Nd4j.createFromArray(4.0));
|
||||
sut.add(Nd4j.createFromArray(5.0));
|
||||
sut.add(Nd4j.createFromArray(6.0));
|
||||
sut.reset();
|
||||
|
||||
// Act
|
||||
INDArray[] result = sut.get();
|
||||
|
||||
// Assert
|
||||
assertEquals(0, result.length);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,378 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.dataset.api.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.datavec.api.transform.Operation;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class TransformProcessTest {
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_noChannelNameIsSuppliedToBuild_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess.builder().build();
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_callingTransformWithNullArg_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.build("test");
|
||||
|
||||
// Act
|
||||
sut.transform(null, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_callingTransformWithEmptyChannelData_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>();
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_addingNullFilter_expect_nullException() {
|
||||
// Act
|
||||
TransformProcess.builder().filter(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_fileteredOut_expect_skippedObservationAndFollowingOperationsSkipped() {
|
||||
// Arrange
|
||||
IntegerTransformOperationMock transformOperationMock = new IntegerTransformOperationMock();
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.filter(new FilterOperationMock(true))
|
||||
.transform("test", transformOperationMock)
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 0, false);
|
||||
|
||||
// Assert
|
||||
assertTrue(result.isSkipped());
|
||||
assertFalse(transformOperationMock.isCalled);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_addingTransformOnNullChannel_expect_nullException() {
|
||||
// Act
|
||||
TransformProcess.builder().transform(null, new IntegerTransformOperationMock());
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_addingTransformWithNullTransform_expect_nullException() {
|
||||
// Act
|
||||
TransformProcess.builder().transform("test", null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_transformIsCalled_expect_channelDataTransformedInSameOrder() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.filter(new FilterOperationMock(false))
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.transform("test", new ToDataSetTransformOperationMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 0, false);
|
||||
|
||||
// Assert
|
||||
assertFalse(result.isSkipped());
|
||||
assertEquals(-1.0, result.getData().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_addingPreProcessOnNullChannel_expect_nullException() {
|
||||
// Act
|
||||
TransformProcess.builder().preProcess(null, new DataSetPreProcessorMock());
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_addingPreProcessWithNullTransform_expect_nullException() {
|
||||
// Act
|
||||
TransformProcess.builder().transform("test", null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_preProcessIsCalled_expect_channelDataPreProcessedInSameOrder() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.filter(new FilterOperationMock(false))
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.transform("test", new ToDataSetTransformOperationMock())
|
||||
.preProcess("test", new DataSetPreProcessorMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 0, false);
|
||||
|
||||
// Assert
|
||||
assertFalse(result.isSkipped());
|
||||
assertEquals(1, result.getData().shape().length);
|
||||
assertEquals(1, result.getData().shape()[0]);
|
||||
assertEquals(-10.0, result.getData().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void when_transformingNullData_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_transformingAndChannelsNotDataSet_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.preProcess("test", new DataSetPreProcessorMock())
|
||||
.build("test");
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(null, 0, false);
|
||||
}
|
||||
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_transformingAndChannelsEmptyDataSet_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.preProcess("test", new DataSetPreProcessorMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>();
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_buildIsCalledWithoutChannelNames_expect_exception() {
|
||||
// Act
|
||||
TransformProcess.builder().build();
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_buildIsCalledWithNullChannelName_expect_exception() {
|
||||
// Act
|
||||
TransformProcess.builder().build(null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_resetIsCalled_expect_resettableAreReset() {
|
||||
// Arrange
|
||||
ResettableTransformOperationMock resettableOperation = new ResettableTransformOperationMock();
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.filter(new FilterOperationMock(false))
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.transform("test", resettableOperation)
|
||||
.build("test");
|
||||
|
||||
// Act
|
||||
sut.reset();
|
||||
|
||||
// Assert
|
||||
assertTrue(resettableOperation.isResetCalled);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildIsCalledAndAllChannelsAreDataSets_expect_observation() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.transform("test", new ToDataSetTransformOperationMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 123, true);
|
||||
|
||||
// Assert
|
||||
assertFalse(result.isSkipped());
|
||||
|
||||
assertEquals(1.0, result.getData().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_buildIsCalledAndAllChannelsAreINDArrays_expect_observation() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", Nd4j.create(new double[] { 1.0 }));
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 123, true);
|
||||
|
||||
// Assert
|
||||
assertFalse(result.isSkipped());
|
||||
|
||||
assertEquals(1.0, result.getData().getDouble(0), 0.00001);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalStateException.class)
|
||||
public void when_buildIsCalledAndChannelsNotDataSetsOrINDArrays_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
Observation result = sut.transform(channelsData, 123, true);
|
||||
}
|
||||
|
||||
@Test(expected = NullPointerException.class)
|
||||
public void when_channelDataIsNull_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", null);
|
||||
}};
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_transformAppliedOnChannelNotInMap_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("not-test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_preProcessAppliedOnChannelNotInMap_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.preProcess("test", new DataSetPreProcessorMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("not-test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_buildContainsChannelNotInMap_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.transform("test", new IntegerTransformOperationMock())
|
||||
.build("not-test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_preProcessNotAppliedOnDataSet_expect_exception() {
|
||||
// Arrange
|
||||
TransformProcess sut = TransformProcess.builder()
|
||||
.preProcess("test", new DataSetPreProcessorMock())
|
||||
.build("test");
|
||||
Map<String, Object> channelsData = new HashMap<String, Object>() {{
|
||||
put("test", 1);
|
||||
}};
|
||||
|
||||
// Act
|
||||
sut.transform(channelsData, 0, false);
|
||||
}
|
||||
|
||||
private static class FilterOperationMock implements FilterOperation {
|
||||
|
||||
private final boolean skipped;
|
||||
|
||||
public FilterOperationMock(boolean skipped) {
|
||||
this.skipped = skipped;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSkipped(Map<String, Object> channelsData, int currentObservationStep, boolean isFinalObservation) {
|
||||
return skipped;
|
||||
}
|
||||
}
|
||||
|
||||
private static class IntegerTransformOperationMock implements Operation<Integer, Integer> {
|
||||
|
||||
public boolean isCalled = false;
|
||||
|
||||
@Override
|
||||
public Integer transform(Integer input) {
|
||||
isCalled = true;
|
||||
return -input;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ToDataSetTransformOperationMock implements Operation<Integer, DataSet> {
|
||||
|
||||
@Override
|
||||
public DataSet transform(Integer input) {
|
||||
return new org.nd4j.linalg.dataset.DataSet(Nd4j.create(new double[] { input }), null);
|
||||
}
|
||||
}
|
||||
|
||||
private static class ResettableTransformOperationMock implements Operation<Integer, Integer>, ResettableOperation {
|
||||
|
||||
private boolean isResetCalled = false;
|
||||
|
||||
@Override
|
||||
public Integer transform(Integer input) {
|
||||
return input * 10;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
isResetCalled = true;
|
||||
}
|
||||
}
|
||||
|
||||
private static class DataSetPreProcessorMock implements DataSetPreProcessor {
|
||||
|
||||
@Override
|
||||
public void preProcess(DataSet dataSet) {
|
||||
dataSet.getFeatures().muli(10.0);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform.filter;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.transform.FilterOperation;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class UniformSkippingFilterTest {
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_negativeSkipFrame_expect_exception() {
|
||||
// Act
|
||||
new UniformSkippingFilter(-1);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_skippingIs4_expect_firstNotSkippedOther3Skipped() {
|
||||
// Assemble
|
||||
FilterOperation sut = new UniformSkippingFilter(4);
|
||||
boolean[] isSkipped = new boolean[8];
|
||||
|
||||
// Act
|
||||
for(int i = 0; i < 8; ++i) {
|
||||
isSkipped[i] = sut.isSkipped(null, i, false);
|
||||
}
|
||||
|
||||
// Assert
|
||||
assertFalse(isSkipped[0]);
|
||||
assertTrue(isSkipped[1]);
|
||||
assertTrue(isSkipped[2]);
|
||||
assertTrue(isSkipped[3]);
|
||||
|
||||
assertFalse(isSkipped[4]);
|
||||
assertTrue(isSkipped[5]);
|
||||
assertTrue(isSkipped[6]);
|
||||
assertTrue(isSkipped[7]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_isLastObservation_expect_observationNotSkipped() {
|
||||
// Assemble
|
||||
FilterOperation sut = new UniformSkippingFilter(4);
|
||||
|
||||
// Act
|
||||
boolean isSkippedNotLastObservation = sut.isSkipped(null, 1, false);
|
||||
boolean isSkippedLastObservation = sut.isSkipped(null, 1, true);
|
||||
|
||||
// Assert
|
||||
assertTrue(isSkippedNotLastObservation);
|
||||
assertFalse(isSkippedLastObservation);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,166 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform.operation;
|
||||
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeAssembler;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.HistoryMergeElementStore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class HistoryMergeTransformTest {
|
||||
|
||||
@Test
|
||||
public void when_firstDimensionIsNotBatch_expect_observationAddedAsIs() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(false);
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.isFirstDimenstionBatch(false)
|
||||
.elementStore(store)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
|
||||
// Act
|
||||
sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, store.addedObservation.shape().length);
|
||||
assertEquals(3, store.addedObservation.shape()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_firstDimensionIsBatch_expect_observationAddedAsSliced() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(false);
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.isFirstDimenstionBatch(true)
|
||||
.elementStore(store)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
|
||||
|
||||
// Act
|
||||
sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, store.addedObservation.shape().length);
|
||||
assertEquals(3, store.addedObservation.shape()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_notReady_expect_resultIsNull() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(false);
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.isFirstDimenstionBatch(true)
|
||||
.elementStore(store)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
|
||||
// Act
|
||||
INDArray result = sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertNull(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_notShouldStoreCopy_expect_sameIsStored() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(false);
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.shouldStoreCopy(false)
|
||||
.elementStore(store)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
|
||||
// Act
|
||||
INDArray result = sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertSame(input, store.addedObservation);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_shouldStoreCopy_expect_copyIsStored() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(true);
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.shouldStoreCopy(true)
|
||||
.elementStore(store)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
|
||||
// Act
|
||||
INDArray result = sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertNotSame(input, store.addedObservation);
|
||||
assertEquals(1, store.addedObservation.shape().length);
|
||||
assertEquals(3, store.addedObservation.shape()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_transformCalled_expect_storeContentAssembledAndOutputHasCorrectShape() {
|
||||
// Arrange
|
||||
MockStore store = new MockStore(true);
|
||||
MockAssemble assemble = new MockAssemble();
|
||||
HistoryMergeTransform sut = HistoryMergeTransform.builder()
|
||||
.elementStore(store)
|
||||
.assembler(assemble)
|
||||
.build();
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
|
||||
// Act
|
||||
INDArray result = sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(1, assemble.assembleElements.length);
|
||||
assertSame(store.addedObservation, assemble.assembleElements[0]);
|
||||
|
||||
assertEquals(2, result.shape().length);
|
||||
assertEquals(1, result.shape()[0]);
|
||||
assertEquals(3, result.shape()[1]);
|
||||
}
|
||||
|
||||
public static class MockStore implements HistoryMergeElementStore {
|
||||
|
||||
private final boolean isReady;
|
||||
private INDArray addedObservation;
|
||||
|
||||
public MockStore(boolean isReady) {
|
||||
|
||||
this.isReady = isReady;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void add(INDArray observation) {
|
||||
addedObservation = observation;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray[] get() {
|
||||
return new INDArray[] { addedObservation };
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isReady() {
|
||||
return isReady;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public static class MockAssemble implements HistoryMergeAssembler {
|
||||
|
||||
private INDArray[] assembleElements;
|
||||
|
||||
@Override
|
||||
public INDArray assemble(INDArray[] elements) {
|
||||
assembleElements = elements;
|
||||
return elements[0];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform.operation;
|
||||
|
||||
import org.deeplearning4j.rl4j.helper.INDArrayHelper;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class SimpleNormalizationTransformTest {
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_maxIsLessThanMin_expect_exception() {
|
||||
// Arrange
|
||||
SimpleNormalizationTransform sut = new SimpleNormalizationTransform(10.0, 1.0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_transformIsCalled_expect_inputNormalized() {
|
||||
// Arrange
|
||||
SimpleNormalizationTransform sut = new SimpleNormalizationTransform(1.0, 11.0);
|
||||
INDArray input = Nd4j.create(new double[] { 1.0, 11.0 });
|
||||
|
||||
// Act
|
||||
INDArray result = sut.transform(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(0.0, result.getDouble(0), 0.00001);
|
||||
assertEquals(1.0, result.getDouble(1), 0.00001);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class CircularFifoStoreTest {
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void when_fifoSizeIsLessThan1_expect_exception() {
|
||||
// Arrange
|
||||
CircularFifoStore sut = new CircularFifoStore(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_adding2elementsWithSize2_expect_notReadyAfter1stReadyAfter2nd() {
|
||||
// Arrange
|
||||
CircularFifoStore sut = new CircularFifoStore(2);
|
||||
INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
|
||||
|
||||
// Act
|
||||
sut.add(firstElement);
|
||||
boolean isReadyAfter1st = sut.isReady();
|
||||
sut.add(secondElement);
|
||||
boolean isReadyAfter2nd = sut.isReady();
|
||||
|
||||
// Assert
|
||||
assertFalse(isReadyAfter1st);
|
||||
assertTrue(isReadyAfter2nd);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_adding2elementsWithSize2_expect_getReturnThese2() {
|
||||
// Arrange
|
||||
CircularFifoStore sut = new CircularFifoStore(2);
|
||||
INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
|
||||
|
||||
// Act
|
||||
sut.add(firstElement);
|
||||
sut.add(secondElement);
|
||||
INDArray[] results = sut.get();
|
||||
|
||||
// Assert
|
||||
assertEquals(2, results.length);
|
||||
|
||||
assertEquals(1.0, results[0].getDouble(0), 0.00001);
|
||||
assertEquals(2.0, results[0].getDouble(1), 0.00001);
|
||||
assertEquals(3.0, results[0].getDouble(2), 0.00001);
|
||||
|
||||
assertEquals(10.0, results[1].getDouble(0), 0.00001);
|
||||
assertEquals(20.0, results[1].getDouble(1), 0.00001);
|
||||
assertEquals(30.0, results[1].getDouble(2), 0.00001);
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
public void when_adding2elementsThenCallingReset_expect_getReturnEmpty() {
|
||||
// Arrange
|
||||
CircularFifoStore sut = new CircularFifoStore(2);
|
||||
INDArray firstElement = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
|
||||
INDArray secondElement = Nd4j.create(new double[] { 10.0, 20.0, 30.0 });
|
||||
|
||||
// Act
|
||||
sut.add(firstElement);
|
||||
sut.add(secondElement);
|
||||
sut.reset();
|
||||
INDArray[] results = sut.get();
|
||||
|
||||
// Assert
|
||||
assertEquals(0, results.length);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package org.deeplearning4j.rl4j.observation.transform.operation.historymerge;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class HistoryStackAssemblerTest {
|
||||
|
||||
@Test
|
||||
public void when_assembling2INDArrays_expect_stackedAsResult() {
|
||||
// Arrange
|
||||
INDArray[] input = new INDArray[] {
|
||||
Nd4j.create(new double[] { 1.0, 2.0, 3.0 }),
|
||||
Nd4j.create(new double[] { 10.0, 20.0, 30.0 }),
|
||||
};
|
||||
HistoryStackAssembler sut = new HistoryStackAssembler();
|
||||
|
||||
// Act
|
||||
INDArray result = sut.assemble(input);
|
||||
|
||||
// Assert
|
||||
assertEquals(2, result.shape().length);
|
||||
assertEquals(2, result.shape()[0]);
|
||||
assertEquals(3, result.shape()[1]);
|
||||
|
||||
assertEquals(1.0, result.getDouble(0, 0), 0.00001);
|
||||
assertEquals(2.0, result.getDouble(0, 1), 0.00001);
|
||||
assertEquals(3.0, result.getDouble(0, 2), 0.00001);
|
||||
|
||||
assertEquals(10.0, result.getDouble(1, 0), 0.00001);
|
||||
assertEquals(20.0, result.getDouble(1, 1), 0.00001);
|
||||
assertEquals(30.0, result.getDouble(1, 2), 0.00001);
|
||||
|
||||
}
|
||||
}
|
|
@ -23,15 +23,18 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
|||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
|
||||
import org.deeplearning4j.rl4j.learning.Learning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
|
||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.QLearningDiscreteTest;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.network.NeuralNet;
|
||||
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
|
||||
import org.deeplearning4j.rl4j.observation.Observation;
|
||||
import org.deeplearning4j.rl4j.space.ActionSpace;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.Encodable;
|
||||
import org.deeplearning4j.rl4j.support.*;
|
||||
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -186,8 +189,8 @@ public class PolicyTest {
|
|||
QLearning.QLConfiguration conf = new QLearning.QLConfiguration(0, 0, 0, 5, 1, 0,
|
||||
0, 1.0, 0, 0, 0, 0, true);
|
||||
MockNeuralNet nnMock = new MockNeuralNet();
|
||||
MockRefacPolicy sut = new MockRefacPolicy(nnMock);
|
||||
IHistoryProcessor.Configuration hpConf = new IHistoryProcessor.Configuration(5, 4, 4, 4, 4, 0, 0, 2);
|
||||
MockRefacPolicy sut = new MockRefacPolicy(nnMock, observationSpace.getShape(), hpConf.getSkipFrame(), hpConf.getHistoryLength());
|
||||
MockHistoryProcessor hp = new MockHistoryProcessor(hpConf);
|
||||
|
||||
// Act
|
||||
|
@ -197,13 +200,6 @@ public class PolicyTest {
|
|||
assertEquals(1, nnMock.resetCallCount);
|
||||
assertEquals(465.0, totalReward, 0.0001);
|
||||
|
||||
// HistoryProcessor
|
||||
assertEquals(16, hp.addCalls.size());
|
||||
assertEquals(31, hp.recordCalls.size());
|
||||
for(int i=0; i <= 30; ++i) {
|
||||
assertEquals((double)i, hp.recordCalls.get(i).getDouble(0), 0.0001);
|
||||
}
|
||||
|
||||
// MDP
|
||||
assertEquals(1, mdp.resetCount);
|
||||
assertEquals(30, mdp.actions.size());
|
||||
|
@ -219,10 +215,15 @@ public class PolicyTest {
|
|||
public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
|
||||
|
||||
private NeuralNet neuralNet;
|
||||
private final int[] shape;
|
||||
private final int skipFrame;
|
||||
private final int historyLength;
|
||||
|
||||
public MockRefacPolicy(NeuralNet neuralNet) {
|
||||
|
||||
public MockRefacPolicy(NeuralNet neuralNet, int[] shape, int skipFrame, int historyLength) {
|
||||
this.neuralNet = neuralNet;
|
||||
this.shape = shape;
|
||||
this.skipFrame = skipFrame;
|
||||
this.historyLength = historyLength;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -239,5 +240,11 @@ public class PolicyTest {
|
|||
public Integer nextAction(INDArray input) {
|
||||
return (int)input.getDouble(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp, RefacEpochStepCounter epochStepCounter) {
|
||||
mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
|
||||
return super.refacInitMdp(mdpWrapper, hp, epochStepCounter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,12 @@ package org.deeplearning4j.rl4j.support;
|
|||
|
||||
import org.deeplearning4j.gym.StepReply;
|
||||
import org.deeplearning4j.rl4j.mdp.MDP;
|
||||
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
|
||||
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
|
||||
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
|
||||
import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
|
||||
import org.deeplearning4j.rl4j.space.DiscreteSpace;
|
||||
import org.deeplearning4j.rl4j.space.ObservationSpace;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
@ -77,4 +83,16 @@ public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
|
|||
public MDP newInstance() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
|
||||
return TransformProcess.builder()
|
||||
.filter(new UniformSkippingFilter(skipFrame))
|
||||
.transform("data", new EncodableToINDArrayTransform(shape))
|
||||
.transform("data", new SimpleNormalizationTransform(0.0, 255.0))
|
||||
.transform("data", HistoryMergeTransform.builder()
|
||||
.elementStore(new CircularFifoStore(historyLength))
|
||||
.build())
|
||||
.build("data");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
|
|||
private static PyObject globals;
|
||||
static {
|
||||
try {
|
||||
Py_SetPath(org.bytedeco.gym.presets.gym.cachePackages());
|
||||
Py_AddPath(org.bytedeco.gym.presets.gym.cachePackages());
|
||||
program = Py_DecodeLocale(GymEnv.class.getSimpleName(), null);
|
||||
Py_SetProgramName(program);
|
||||
Py_Initialize();
|
||||
|
|
Loading…
Reference in New Issue