RL4J refac: Added some observation transform classes (#7958)
* Added observation classes and tests Signed-off-by: unknown <aboulang2002@yahoo.com> * Now uses DataSetPreProcessors Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * CompositeDataSetPreProcessor can now stop processing on empty dataset; Some DataSetPreProcessors moving from RL4J to ND4J Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com> * Did requested minor changes Signed-off-by: Alexandre Boulanger <Alexandre.Boulanger@ia.ca> Signed-off-by: Alexandre Boulanger <aboulang2002@yahoo.com>master
parent
9bb11d5b06
commit
ee6aae268f
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.dataset.api.preprocessor;
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
@ -29,19 +30,35 @@ import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
|
||||||
*/
|
*/
|
||||||
public class CompositeDataSetPreProcessor implements DataSetPreProcessor {
|
public class CompositeDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
|
private final boolean stopOnEmptyDataSet;
|
||||||
private DataSetPreProcessor[] preProcessors;
|
private DataSetPreProcessor[] preProcessors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param preProcessors Preprocessors to apply. They will be applied in this order
|
* @param preProcessors Preprocessors to apply. They will be applied in this order
|
||||||
*/
|
*/
|
||||||
public CompositeDataSetPreProcessor(DataSetPreProcessor... preProcessors){
|
public CompositeDataSetPreProcessor(DataSetPreProcessor... preProcessors) {
|
||||||
|
this(false, preProcessors);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CompositeDataSetPreProcessor(boolean stopOnEmptyDataSet, DataSetPreProcessor... preProcessors){
|
||||||
|
this.stopOnEmptyDataSet = stopOnEmptyDataSet;
|
||||||
this.preProcessors = preProcessors;
|
this.preProcessors = preProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void preProcess(DataSet dataSet) {
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(stopOnEmptyDataSet && dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for(DataSetPreProcessor p : preProcessors){
|
for(DataSetPreProcessor p : preProcessors){
|
||||||
p.preProcess(dataSet);
|
p.preProcess(dataSet);
|
||||||
|
|
||||||
|
if(stopOnEmptyDataSet && dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,105 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The CropAndResizeDataSetPreProcessor will crop and resize the processed dataset.
|
||||||
|
* NOTE: The data format must be NHWC
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class CropAndResizeDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
|
public enum ResizeMethod {
|
||||||
|
Bilinear,
|
||||||
|
NearestNeighbor
|
||||||
|
}
|
||||||
|
|
||||||
|
private final long[] resizedShape;
|
||||||
|
private final INDArray indices;
|
||||||
|
private final INDArray resize;
|
||||||
|
private final INDArray boxes;
|
||||||
|
private final int method;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param originalHeight Height of the input datasets
|
||||||
|
* @param originalWidth Width of the input datasets
|
||||||
|
* @param cropYStart y coord of the starting point on the input datasets
|
||||||
|
* @param cropXStart x coord of the starting point on the input datasets
|
||||||
|
* @param resizedHeight Height of the output dataset
|
||||||
|
* @param resizedWidth Width of the output dataset
|
||||||
|
* @param numChannels
|
||||||
|
* @param resizeMethod
|
||||||
|
*/
|
||||||
|
public CropAndResizeDataSetPreProcessor(int originalHeight, int originalWidth, int cropYStart, int cropXStart, int resizedHeight, int resizedWidth, int numChannels, ResizeMethod resizeMethod) {
|
||||||
|
Preconditions.checkArgument(originalHeight > 0, "originalHeight must be greater than 0, got %s", originalHeight);
|
||||||
|
Preconditions.checkArgument(originalWidth > 0, "originalWidth must be greater than 0, got %s", originalWidth);
|
||||||
|
Preconditions.checkArgument(cropYStart >= 0, "cropYStart must be greater or equal to 0, got %s", cropYStart);
|
||||||
|
Preconditions.checkArgument(cropXStart >= 0, "cropXStart must be greater or equal to 0, got %s", cropXStart);
|
||||||
|
Preconditions.checkArgument(resizedHeight > 0, "resizedHeight must be greater than 0, got %s", resizedHeight);
|
||||||
|
Preconditions.checkArgument(resizedWidth > 0, "resizedWidth must be greater than 0, got %s", resizedWidth);
|
||||||
|
Preconditions.checkArgument(numChannels > 0, "numChannels must be greater than 0, got %s", numChannels);
|
||||||
|
|
||||||
|
resizedShape = new long[] { 1, resizedHeight, resizedWidth, numChannels };
|
||||||
|
|
||||||
|
boxes = Nd4j.create(new float[] {
|
||||||
|
(float)cropYStart / (float)originalHeight,
|
||||||
|
(float)cropXStart / (float)originalWidth,
|
||||||
|
(float)(cropYStart + resizedHeight) / (float)originalHeight,
|
||||||
|
(float)(cropXStart + resizedWidth) / (float)originalWidth
|
||||||
|
}, new long[] { 1, 4 }, DataType.FLOAT);
|
||||||
|
indices = Nd4j.create(new int[] { 0 }, new long[] { 1, 1 }, DataType.INT);
|
||||||
|
|
||||||
|
resize = Nd4j.create(new int[] { resizedHeight, resizedWidth }, new long[] { 1, 2 }, DataType.INT);
|
||||||
|
method = resizeMethod == ResizeMethod.Bilinear ? 0 : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* NOTE: The data format must be NHWC
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray input = dataSet.getFeatures();
|
||||||
|
INDArray output = Nd4j.create(LongShapeDescriptor.fromShape(resizedShape, input.dataType()), false);
|
||||||
|
|
||||||
|
CustomOp op = DynamicCustomOp.builder("crop_and_resize")
|
||||||
|
.addInputs(input, boxes, indices, resize)
|
||||||
|
.addIntegerArguments(method)
|
||||||
|
.addOutputs(output)
|
||||||
|
.build();
|
||||||
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
|
||||||
|
dataSet.setFeatures(output);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,87 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The PermuteDataSetPreProcessor will rearrange the dimensions.
|
||||||
|
* There are two pre-defined permutation types:
|
||||||
|
* - from NCHW to NHWC
|
||||||
|
* - from NHWC to NCHW
|
||||||
|
*
|
||||||
|
* Or, pass the new order to the ctor. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class PermuteDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
|
private final PermutationTypes permutationType;
|
||||||
|
private final int[] rearrange;
|
||||||
|
|
||||||
|
public enum PermutationTypes { NCHWtoNHWC, NHWCtoNCHW, Custom }
|
||||||
|
|
||||||
|
public PermuteDataSetPreProcessor(PermutationTypes permutationType) {
|
||||||
|
Preconditions.checkArgument(permutationType != PermutationTypes.Custom, "Use the ctor PermuteDataSetPreProcessor(int... rearrange) for custom permutations.");
|
||||||
|
|
||||||
|
this.permutationType = permutationType;
|
||||||
|
rearrange = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param rearrange The new order. For example PermuteDataSetPreProcessor(1, 2, 0) will rearrange the middle dimension first, the last one in the middle and the first one last.
|
||||||
|
*/
|
||||||
|
public PermuteDataSetPreProcessor(int... rearrange) {
|
||||||
|
|
||||||
|
this.permutationType = PermutationTypes.Custom;
|
||||||
|
this.rearrange = rearrange;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray input = dataSet.getFeatures();
|
||||||
|
INDArray output;
|
||||||
|
switch (permutationType) {
|
||||||
|
case NCHWtoNHWC:
|
||||||
|
output = input.permute(0, 2, 3, 1);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case NHWCtoNCHW:
|
||||||
|
output = input.permute(0, 3, 1, 2);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case Custom:
|
||||||
|
output = input.permute(rearrange);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
output = input;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
dataSet.setFeatures(output);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The RGBtoGrayscaleDataSetPreProcessor will turn a DataSet of a RGB image into a grayscale one.
|
||||||
|
* NOTE: Expects data format to be NCHW. After processing, the channel dimension is eliminated. (NCHW -> NHW)
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class RGBtoGrayscaleDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
|
private static final float RED_RATIO = 0.3f;
|
||||||
|
private static final float GREEN_RATIO = 0.59f;
|
||||||
|
private static final float BLUE_RATIO = 0.11f;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray originalFeatures = dataSet.getFeatures();
|
||||||
|
long[] originalShape = originalFeatures.shape();
|
||||||
|
|
||||||
|
// result shape is NHW
|
||||||
|
INDArray result = Nd4j.create(originalShape[0], originalShape[2], originalShape[3]);
|
||||||
|
|
||||||
|
for(long n = 0, numExamples = originalShape[0]; n < numExamples; ++n) {
|
||||||
|
// Extract channels
|
||||||
|
INDArray itemFeatures = originalFeatures.slice(n, 0); // shape is CHW
|
||||||
|
INDArray R = itemFeatures.slice(0, 0); // shape is HW
|
||||||
|
INDArray G = itemFeatures.slice(1, 0);
|
||||||
|
INDArray B = itemFeatures.slice(2, 0);
|
||||||
|
|
||||||
|
// Convert
|
||||||
|
R.muli(RED_RATIO);
|
||||||
|
G.muli(GREEN_RATIO);
|
||||||
|
B.muli(BLUE_RATIO);
|
||||||
|
R.addi(G).addi(B);
|
||||||
|
|
||||||
|
// FIXME: int cast
|
||||||
|
result.putSlice((int)n, R);
|
||||||
|
}
|
||||||
|
|
||||||
|
dataSet.setFeatures(result);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,102 @@
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class CompositeDataSetPreProcessorTest {
|
||||||
|
@Test(expected = NullPointerException.class)
|
||||||
|
public void when_preConditionsIsNull_expect_NullPointerException() {
|
||||||
|
// Assemble
|
||||||
|
CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(null);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsEmpty_expect_emptyDataSet() {
|
||||||
|
// Assemble
|
||||||
|
CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor();
|
||||||
|
DataSet ds = new DataSet(null, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(ds.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_notStoppingOnEmptyDataSet_expect_allPreProcessorsCalled() {
|
||||||
|
// Assemble
|
||||||
|
TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true);
|
||||||
|
TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true);
|
||||||
|
CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(preProcessor1, preProcessor2);
|
||||||
|
DataSet ds = new DataSet(Nd4j.rand(2, 2), null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(preProcessor1.hasBeenCalled);
|
||||||
|
assertTrue(preProcessor2.hasBeenCalled);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_stoppingOnEmptyDataSetAndFirstPreProcessorClearDS_expect_firstPreProcessorsCalled() {
|
||||||
|
// Assemble
|
||||||
|
TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(true);
|
||||||
|
TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(true);
|
||||||
|
CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2);
|
||||||
|
DataSet ds = new DataSet(Nd4j.rand(2, 2), null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(preProcessor1.hasBeenCalled);
|
||||||
|
assertFalse(preProcessor2.hasBeenCalled);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_stoppingOnEmptyDataSet_expect_firstPreProcessorsCalled() {
|
||||||
|
// Assemble
|
||||||
|
TestDataSetPreProcessor preProcessor1 = new TestDataSetPreProcessor(false);
|
||||||
|
TestDataSetPreProcessor preProcessor2 = new TestDataSetPreProcessor(false);
|
||||||
|
CompositeDataSetPreProcessor sut = new CompositeDataSetPreProcessor(true, preProcessor1, preProcessor2);
|
||||||
|
DataSet ds = new DataSet(Nd4j.rand(2, 2), null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(preProcessor1.hasBeenCalled);
|
||||||
|
assertTrue(preProcessor2.hasBeenCalled);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class TestDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
|
||||||
|
private final boolean clearDataSet;
|
||||||
|
|
||||||
|
public boolean hasBeenCalled;
|
||||||
|
|
||||||
|
public TestDataSetPreProcessor(boolean clearDataSet) {
|
||||||
|
this.clearDataSet = clearDataSet;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preProcess(org.nd4j.linalg.dataset.api.DataSet dataSet) {
|
||||||
|
hasBeenCalled = true;
|
||||||
|
if(clearDataSet) {
|
||||||
|
dataSet.setFeatures(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,131 @@
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class CropAndResizeDataSetPreProcessorTest {
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_originalHeightIsZero_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(0, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_originalWidthIsZero_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 0, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_yStartIsNegative_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, -1, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_xStartIsNegative_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, -1, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_heightIsNotGreaterThanZero_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 0, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_widthIsNotGreaterThanZero_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 0, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_numChannelsIsNotGreaterThanZero_expect_IllegalArgumentException() {
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 0, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = NullPointerException.class)
|
||||||
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
// Assemble
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsEmpty_expect_emptyDataSet() {
|
||||||
|
// Assemble
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(10, 15, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
DataSet ds = new DataSet(null, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(ds.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIs15wx10h_expect_3wx4hDataSet() {
|
||||||
|
// Assemble
|
||||||
|
int numChannels = 3;
|
||||||
|
int height = 10;
|
||||||
|
int width = 15;
|
||||||
|
CropAndResizeDataSetPreProcessor sut = new CropAndResizeDataSetPreProcessor(height, width, 5, 5, 4, 3, 3, CropAndResizeDataSetPreProcessor.ResizeMethod.NearestNeighbor);
|
||||||
|
INDArray input = Nd4j.create(LongShapeDescriptor.fromShape(new int[] { 1, height, width, numChannels }, DataType.FLOAT), true);
|
||||||
|
for(int c = 0; c < numChannels; ++c) {
|
||||||
|
for(int h = 0; h < height; ++h) {
|
||||||
|
for(int w = 0; w < width; ++w) {
|
||||||
|
input.putScalar(0, h, w, c, c*100 + h*10 + w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(input, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
INDArray results = ds.getFeatures();
|
||||||
|
long[] shape = results.shape();
|
||||||
|
assertEquals(1, shape[0]);
|
||||||
|
assertEquals(4, shape[1]);
|
||||||
|
assertEquals(3, shape[2]);
|
||||||
|
assertEquals(3, shape[3]);
|
||||||
|
|
||||||
|
// Test a few values
|
||||||
|
assertEquals(55.0, results.getDouble(0, 0, 0, 0), 0.0);
|
||||||
|
assertEquals(155.0, results.getDouble(0, 0, 0, 1), 0.0);
|
||||||
|
assertEquals(255.0, results.getDouble(0, 0, 0, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(56.0, results.getDouble(0, 0, 1, 0), 0.0);
|
||||||
|
assertEquals(156.0, results.getDouble(0, 0, 1, 1), 0.0);
|
||||||
|
assertEquals(256.0, results.getDouble(0, 0, 1, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(57.0, results.getDouble(0, 0, 2, 0), 0.0);
|
||||||
|
assertEquals(157.0, results.getDouble(0, 0, 2, 1), 0.0);
|
||||||
|
assertEquals(257.0, results.getDouble(0, 0, 2, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(65.0, results.getDouble(0, 1, 0, 0), 0.0);
|
||||||
|
assertEquals(165.0, results.getDouble(0, 1, 0, 1), 0.0);
|
||||||
|
assertEquals(265.0, results.getDouble(0, 1, 0, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(66.0, results.getDouble(0, 1, 1, 0), 0.0);
|
||||||
|
assertEquals(166.0, results.getDouble(0, 1, 1, 1), 0.0);
|
||||||
|
assertEquals(266.0, results.getDouble(0, 1, 1, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(75.0, results.getDouble(0, 2, 0, 0), 0.0);
|
||||||
|
assertEquals(175.0, results.getDouble(0, 2, 0, 1), 0.0);
|
||||||
|
assertEquals(275.0, results.getDouble(0, 2, 0, 2), 0.0);
|
||||||
|
|
||||||
|
assertEquals(76.0, results.getDouble(0, 2, 1, 0), 0.0);
|
||||||
|
assertEquals(176.0, results.getDouble(0, 2, 1, 1), 0.0);
|
||||||
|
assertEquals(276.0, results.getDouble(0, 2, 1, 2), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,124 @@
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.PermuteDataSetPreProcessor;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
public class PermuteDataSetPreProcessorTest {
|
||||||
|
|
||||||
|
@Test(expected = NullPointerException.class)
|
||||||
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
// Assemble
|
||||||
|
PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_emptyDatasetInInputdataSetIsNCHW_expect_emptyDataSet() {
|
||||||
|
// Assemble
|
||||||
|
PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC);
|
||||||
|
DataSet ds = new DataSet(null, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(ds.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsNCHW_expect_dataSetTransformedToNHWC() {
|
||||||
|
// Assemble
|
||||||
|
int numChannels = 3;
|
||||||
|
int height = 5;
|
||||||
|
int width = 4;
|
||||||
|
PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NCHWtoNHWC);
|
||||||
|
INDArray input = Nd4j.create(1, numChannels, height, width);
|
||||||
|
for(int c = 0; c < numChannels; ++c) {
|
||||||
|
for(int h = 0; h < height; ++h) {
|
||||||
|
for(int w = 0; w < width; ++w) {
|
||||||
|
input.putScalar(0, c, h, w, c*100.0 + h*10.0 + w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DataSet ds = new DataSet(input, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
INDArray result = ds.getFeatures();
|
||||||
|
long[] shape = result.shape();
|
||||||
|
assertEquals(1, shape[0]);
|
||||||
|
assertEquals(height, shape[1]);
|
||||||
|
assertEquals(width, shape[2]);
|
||||||
|
assertEquals(numChannels, shape[3]);
|
||||||
|
|
||||||
|
assertEquals(0.0, result.getDouble(0, 0, 0, 0), 0.0);
|
||||||
|
assertEquals(1.0, result.getDouble(0, 0, 1, 0), 0.0);
|
||||||
|
assertEquals(2.0, result.getDouble(0, 0, 2, 0), 0.0);
|
||||||
|
assertEquals(3.0, result.getDouble(0, 0, 3, 0), 0.0);
|
||||||
|
|
||||||
|
assertEquals(110.0, result.getDouble(0, 1, 0, 1), 0.0);
|
||||||
|
assertEquals(111.0, result.getDouble(0, 1, 1, 1), 0.0);
|
||||||
|
assertEquals(112.0, result.getDouble(0, 1, 2, 1), 0.0);
|
||||||
|
assertEquals(113.0, result.getDouble(0, 1, 3, 1), 0.0);
|
||||||
|
|
||||||
|
assertEquals(210.0, result.getDouble(0, 1, 0, 2), 0.0);
|
||||||
|
assertEquals(211.0, result.getDouble(0, 1, 1, 2), 0.0);
|
||||||
|
assertEquals(212.0, result.getDouble(0, 1, 2, 2), 0.0);
|
||||||
|
assertEquals(213.0, result.getDouble(0, 1, 3, 2), 0.0);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsNHWC_expect_dataSetTransformedToNCHW() {
|
||||||
|
// Assemble
|
||||||
|
int numChannels = 3;
|
||||||
|
int height = 5;
|
||||||
|
int width = 4;
|
||||||
|
PermuteDataSetPreProcessor sut = new PermuteDataSetPreProcessor(PermuteDataSetPreProcessor.PermutationTypes.NHWCtoNCHW);
|
||||||
|
INDArray input = Nd4j.create(1, height, width, numChannels);
|
||||||
|
for(int c = 0; c < numChannels; ++c) {
|
||||||
|
for(int h = 0; h < height; ++h) {
|
||||||
|
for(int w = 0; w < width; ++w) {
|
||||||
|
input.putScalar(new int[] { 0, h, w, c }, c*100.0 + h*10.0 + w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DataSet ds = new DataSet(input, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
INDArray result = ds.getFeatures();
|
||||||
|
long[] shape = result.shape();
|
||||||
|
assertEquals(1, shape[0]);
|
||||||
|
assertEquals(numChannels, shape[1]);
|
||||||
|
assertEquals(height, shape[2]);
|
||||||
|
assertEquals(width, shape[3]);
|
||||||
|
|
||||||
|
assertEquals(0.0, result.getDouble(0, 0, 0, 0), 0.0);
|
||||||
|
assertEquals(1.0, result.getDouble(0, 0, 0, 1), 0.0);
|
||||||
|
assertEquals(2.0, result.getDouble(0, 0, 0, 2), 0.0);
|
||||||
|
assertEquals(3.0, result.getDouble(0, 0, 0, 3), 0.0);
|
||||||
|
|
||||||
|
assertEquals(110.0, result.getDouble(0, 1, 1, 0), 0.0);
|
||||||
|
assertEquals(111.0, result.getDouble(0, 1, 1, 1), 0.0);
|
||||||
|
assertEquals(112.0, result.getDouble(0, 1, 1, 2), 0.0);
|
||||||
|
assertEquals(113.0, result.getDouble(0, 1, 1, 3), 0.0);
|
||||||
|
|
||||||
|
assertEquals(210.0, result.getDouble(0, 2, 1, 0), 0.0);
|
||||||
|
assertEquals(211.0, result.getDouble(0, 2, 1, 1), 0.0);
|
||||||
|
assertEquals(212.0, result.getDouble(0, 2, 1, 2), 0.0);
|
||||||
|
assertEquals(213.0, result.getDouble(0, 2, 1, 3), 0.0);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,123 @@
|
||||||
|
package org.nd4j.linalg.dataset.api.preprocessor;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class RGBtoGrayscaleDataSetPreProcessorTest {
|
||||||
|
|
||||||
|
@Test(expected = NullPointerException.class)
|
||||||
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
// Assemble
|
||||||
|
RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsEmpty_expect_EmptyDataSet() {
|
||||||
|
// Assemble
|
||||||
|
RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor();
|
||||||
|
DataSet ds = new DataSet(null, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(ds.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_colorsAreConverted_expect_grayScaleResult() {
|
||||||
|
// Assign
|
||||||
|
int numChannels = 3;
|
||||||
|
int height = 1;
|
||||||
|
int width = 5;
|
||||||
|
|
||||||
|
RGBtoGrayscaleDataSetPreProcessor sut = new RGBtoGrayscaleDataSetPreProcessor();
|
||||||
|
INDArray input = Nd4j.create(2, numChannels, height, width);
|
||||||
|
|
||||||
|
// Black, Example 1
|
||||||
|
input.putScalar(0, 0, 0, 0, 0.0 );
|
||||||
|
input.putScalar(0, 1, 0, 0, 0.0 );
|
||||||
|
input.putScalar(0, 2, 0, 0, 0.0 );
|
||||||
|
|
||||||
|
// White, Example 1
|
||||||
|
input.putScalar(0, 0, 0, 1, 255.0 );
|
||||||
|
input.putScalar(0, 1, 0, 1, 255.0 );
|
||||||
|
input.putScalar(0, 2, 0, 1, 255.0 );
|
||||||
|
|
||||||
|
// Red, Example 1
|
||||||
|
input.putScalar(0, 0, 0, 2, 255.0 );
|
||||||
|
input.putScalar(0, 1, 0, 2, 0.0 );
|
||||||
|
input.putScalar(0, 2, 0, 2, 0.0 );
|
||||||
|
|
||||||
|
// Green, Example 1
|
||||||
|
input.putScalar(0, 0, 0, 3, 0.0 );
|
||||||
|
input.putScalar(0, 1, 0, 3, 255.0 );
|
||||||
|
input.putScalar(0, 2, 0, 3, 0.0 );
|
||||||
|
|
||||||
|
// Blue, Example 1
|
||||||
|
input.putScalar(0, 0, 0, 4, 0.0 );
|
||||||
|
input.putScalar(0, 1, 0, 4, 0.0 );
|
||||||
|
input.putScalar(0, 2, 0, 4, 255.0 );
|
||||||
|
|
||||||
|
|
||||||
|
// Black, Example 2
|
||||||
|
input.putScalar(1, 0, 0, 4, 0.0 );
|
||||||
|
input.putScalar(1, 1, 0, 4, 0.0 );
|
||||||
|
input.putScalar(1, 2, 0, 4, 0.0 );
|
||||||
|
|
||||||
|
// White, Example 2
|
||||||
|
input.putScalar(1, 0, 0, 3, 255.0 );
|
||||||
|
input.putScalar(1, 1, 0, 3, 255.0 );
|
||||||
|
input.putScalar(1, 2, 0, 3, 255.0 );
|
||||||
|
|
||||||
|
// Red, Example 2
|
||||||
|
input.putScalar(1, 0, 0, 2, 255.0 );
|
||||||
|
input.putScalar(1, 1, 0, 2, 0.0 );
|
||||||
|
input.putScalar(1, 2, 0, 2, 0.0 );
|
||||||
|
|
||||||
|
// Green, Example 2
|
||||||
|
input.putScalar(1, 0, 0, 1, 0.0 );
|
||||||
|
input.putScalar(1, 1, 0, 1, 255.0 );
|
||||||
|
input.putScalar(1, 2, 0, 1, 0.0 );
|
||||||
|
|
||||||
|
// Blue, Example 2
|
||||||
|
input.putScalar(1, 0, 0, 0, 0.0 );
|
||||||
|
input.putScalar(1, 1, 0, 0, 0.0 );
|
||||||
|
input.putScalar(1, 2, 0, 0, 255.0 );
|
||||||
|
|
||||||
|
DataSet ds = new DataSet(input, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
INDArray result = ds.getFeatures();
|
||||||
|
long[] shape = result.shape();
|
||||||
|
|
||||||
|
assertEquals(3, shape.length);
|
||||||
|
assertEquals(2, shape[0]);
|
||||||
|
assertEquals(1, shape[1]);
|
||||||
|
assertEquals(5, shape[2]);
|
||||||
|
|
||||||
|
assertEquals(0.0, result.getDouble(0, 0, 0), 0.05);
|
||||||
|
assertEquals(255.0, result.getDouble(0, 0, 1), 0.05);
|
||||||
|
assertEquals(255.0 * 0.3, result.getDouble(0, 0, 2), 0.05);
|
||||||
|
assertEquals(255.0 * 0.59, result.getDouble(0, 0, 3), 0.05);
|
||||||
|
assertEquals(255.0 * 0.11, result.getDouble(0, 0, 4), 0.05);
|
||||||
|
|
||||||
|
assertEquals(0.0, result.getDouble(1, 0, 4), 0.05);
|
||||||
|
assertEquals(255.0, result.getDouble(1, 0, 3), 0.05);
|
||||||
|
assertEquals(255.0 * 0.3, result.getDouble(1, 0, 2), 0.05);
|
||||||
|
assertEquals(255.0 * 0.59, result.getDouble(1, 0, 1), 0.05);
|
||||||
|
assertEquals(255.0 * 0.11, result.getDouble(1, 0, 0), 0.05);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,130 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ChannelStackPoolContentAssembler;
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.CircularFifoObservationPool;
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The PoolingDataSetPreProcessor will accumulate features from incoming DataSets and will assemble its content
|
||||||
|
* into a DataSet containing a single example.
|
||||||
|
*
|
||||||
|
* There are two special cases:
|
||||||
|
* 1) preProcess will return without doing anything if the input DataSet is empty
|
||||||
|
* 2) When the pool has not yet filled, the data from the incoming DataSet is stored in the pool but the DataSet is emptied
|
||||||
|
* on exit.
|
||||||
|
* <br>
|
||||||
|
* The PoolingDataSetPreProcessor requires two sub components: <br>
|
||||||
|
* 1) The ObservationPool that supervises what and how input observations are kept. (ex.: Circular FIFO, trailing min/max/avg, etc...)
|
||||||
|
* The default is a Circular FIFO.
|
||||||
|
* 2) The PoolContentAssembler that will assemble the pool content into a resulting single INDArray. (ex.: stacked along a dimention, squashed into a single observation, etc...)
|
||||||
|
* The default is stacking along the dimension 0.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class PoolingDataSetPreProcessor extends ResettableDataSetPreProcessor {
|
||||||
|
private final ObservationPool observationPool;
|
||||||
|
private final PoolContentAssembler poolContentAssembler;
|
||||||
|
|
||||||
|
protected PoolingDataSetPreProcessor(PoolingDataSetPreProcessor.Builder builder)
|
||||||
|
{
|
||||||
|
observationPool = builder.observationPool;
|
||||||
|
poolContentAssembler = builder.poolContentAssembler;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Note: preProcess will empty the processed dataset if the pool has not filled. Empty datasets should ignored by the
|
||||||
|
* Policy/Learning class and other DataSetPreProcessors
|
||||||
|
*
|
||||||
|
* @param dataSet
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Preconditions.checkArgument(dataSet.numExamples() == 1, "Pooling datasets conatining more than one example is not supported");
|
||||||
|
|
||||||
|
// store a duplicate in the pool
|
||||||
|
observationPool.add(dataSet.getFeatures().slice(0, 0).dup());
|
||||||
|
if(!observationPool.isAtFullCapacity()) {
|
||||||
|
dataSet.setFeatures(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray result = poolContentAssembler.assemble(observationPool.get());
|
||||||
|
|
||||||
|
// return a DataSet containing only 1 example (the result)
|
||||||
|
long[] resultShape = result.shape();
|
||||||
|
long[] newShape = new long[resultShape.length + 1];
|
||||||
|
newShape[0] = 1;
|
||||||
|
System.arraycopy(resultShape, 0, newShape, 1, resultShape.length);
|
||||||
|
|
||||||
|
dataSet.setFeatures(result.reshape(newShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PoolingDataSetPreProcessor.Builder builder() {
|
||||||
|
return new PoolingDataSetPreProcessor.Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
observationPool.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private ObservationPool observationPool;
|
||||||
|
private PoolContentAssembler poolContentAssembler;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default is CircularFifoObservationPool
|
||||||
|
*/
|
||||||
|
public PoolingDataSetPreProcessor.Builder observablePool(ObservationPool observationPool) {
|
||||||
|
this.observationPool = observationPool;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default is ChannelStackPoolContentAssembler
|
||||||
|
*/
|
||||||
|
public PoolingDataSetPreProcessor.Builder poolContentAssembler(PoolContentAssembler poolContentAssembler) {
|
||||||
|
this.poolContentAssembler = poolContentAssembler;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PoolingDataSetPreProcessor build() {
|
||||||
|
if(observationPool == null) {
|
||||||
|
observationPool = new CircularFifoObservationPool();
|
||||||
|
}
|
||||||
|
|
||||||
|
if(poolContentAssembler == null) {
|
||||||
|
poolContentAssembler = new ChannelStackPoolContentAssembler();
|
||||||
|
}
|
||||||
|
|
||||||
|
return new PoolingDataSetPreProcessor(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A base class for all DataSetPreProcessor that must be reset between each MDP sessions (games).
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public abstract class ResettableDataSetPreProcessor implements DataSetPreProcessor {
|
||||||
|
public abstract void reset();
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.dataset.api.DataSet;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The SkippingDataSetPreProcessor will either do nothing to the input (when not skipped) or will empty
|
||||||
|
* the input DataSet when skipping.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class SkippingDataSetPreProcessor extends ResettableDataSetPreProcessor {
|
||||||
|
|
||||||
|
private final int skipFrame;
|
||||||
|
|
||||||
|
private int currentIdx = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param skipFrame For example, a skipFrame of 4 will skip 3 out of 4 observations.
|
||||||
|
*/
|
||||||
|
@Builder
|
||||||
|
public SkippingDataSetPreProcessor(int skipFrame) {
|
||||||
|
Preconditions.checkArgument(skipFrame > 0, "skipFrame must be greater than 0, got %s", skipFrame);
|
||||||
|
this.skipFrame = skipFrame;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void preProcess(DataSet dataSet) {
|
||||||
|
Preconditions.checkNotNull(dataSet, "Encountered null dataSet");
|
||||||
|
|
||||||
|
if(dataSet.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(currentIdx++ % skipFrame != 0) {
|
||||||
|
dataSet.setFeatures(null);
|
||||||
|
dataSet.setLabels(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
currentIdx = 0;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ChannelStackPoolContentAssembler is used with the PoolingDataSetPreProcessor. This assembler will
|
||||||
|
* stack along the dimension 0. For example if the pool elements are of shape [ Height, Width ]
|
||||||
|
* the output will be of shape [ Stacked, Height, Width ]
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class ChannelStackPoolContentAssembler implements PoolContentAssembler {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Will return a new INDArray with one more dimension and with poolContent stacked along dimension 0.
|
||||||
|
*
|
||||||
|
* @param poolContent Array of INDArray
|
||||||
|
* @return A new INDArray with 1 more dimension than the input elements
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public INDArray assemble(INDArray[] poolContent)
|
||||||
|
{
|
||||||
|
// build the new shape
|
||||||
|
long[] elementShape = poolContent[0].shape();
|
||||||
|
long[] newShape = new long[elementShape.length + 1];
|
||||||
|
newShape[0] = poolContent.length;
|
||||||
|
System.arraycopy(elementShape, 0, newShape, 1, elementShape.length);
|
||||||
|
|
||||||
|
// put pool elements in result
|
||||||
|
INDArray result = Nd4j.create(newShape);
|
||||||
|
for(int i = 0; i < poolContent.length; ++i) {
|
||||||
|
result.putRow(i, poolContent[i]);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CircularFifoObservationPool is used with the PoolingDataSetPreProcessor. This pool is a first-in first-out queue
|
||||||
|
* with a fixed size that replaces its oldest element if full.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public class CircularFifoObservationPool implements ObservationPool {
|
||||||
|
private static final int DEFAULT_POOL_SIZE = 4;
|
||||||
|
|
||||||
|
private final CircularFifoQueue<INDArray> queue;
|
||||||
|
|
||||||
|
private CircularFifoObservationPool(Builder builder) {
|
||||||
|
queue = new CircularFifoQueue<>(builder.poolSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CircularFifoObservationPool()
|
||||||
|
{
|
||||||
|
this(DEFAULT_POOL_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CircularFifoObservationPool(int poolSize)
|
||||||
|
{
|
||||||
|
Preconditions.checkArgument(poolSize > 0, "The pool size must be at least 1, got %s", poolSize);
|
||||||
|
queue = new CircularFifoQueue<>(poolSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add an element to the pool, if this addition would make the pool to overflow, the added element replaces the oldest one.
|
||||||
|
* @param elem
|
||||||
|
*/
|
||||||
|
public void add(INDArray elem) {
|
||||||
|
queue.add(elem);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The content of the pool, returned in order from oldest to newest.
|
||||||
|
*/
|
||||||
|
public INDArray[] get() {
|
||||||
|
int size = queue.size();
|
||||||
|
INDArray[] array = new INDArray[size];
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
array[i] = queue.get(i).castTo(Nd4j.dataType());
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isAtFullCapacity() {
|
||||||
|
return queue.isAtFullCapacity();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
queue.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Builder builder() {
|
||||||
|
return new Builder();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static class Builder {
|
||||||
|
private int poolSize = DEFAULT_POOL_SIZE;
|
||||||
|
|
||||||
|
public Builder poolSize(int poolSize) {
|
||||||
|
this.poolSize = poolSize;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public CircularFifoObservationPool build() {
|
||||||
|
return new CircularFifoObservationPool(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ObservationPool is used with the PoolingDataSetPreProcessor. Used to supervise how data from the
|
||||||
|
* PoolingDataSetPreProcessor is stored.
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface ObservationPool {
|
||||||
|
void add(INDArray observation);
|
||||||
|
INDArray[] get();
|
||||||
|
boolean isAtFullCapacity();
|
||||||
|
void reset();
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A PoolContentAssembler is used with the PoolingDataSetPreProcessor. This interface defines how the array of INDArray
|
||||||
|
* returned by the ObservationPool is packaged into the single INDArray that will be set
|
||||||
|
* in the DataSet of PoolingDataSetPreProcessor.preProcess
|
||||||
|
*
|
||||||
|
* @author Alexandre Boulanger
|
||||||
|
*/
|
||||||
|
public interface PoolContentAssembler {
|
||||||
|
INDArray assemble(INDArray[] poolContent);
|
||||||
|
}
|
|
@ -0,0 +1,164 @@
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||||
|
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.ObservationPool;
|
||||||
|
import org.deeplearning4j.rl4j.observation.preprocessor.pooling.PoolContentAssembler;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static junit.framework.TestCase.assertTrue;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class PoolingDataSetPreProcessorTest {
|
||||||
|
|
||||||
|
@Test(expected = NullPointerException.class)
|
||||||
|
public void when_dataSetIsNull_expect_NullPointerException() {
|
||||||
|
// Assemble
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_dataSetHasMoreThanOneExample_expect_IllegalArgumentException() {
|
||||||
|
// Assemble
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(new DataSet(Nd4j.rand(new long[] { 2, 2, 2 }), null));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_dataSetIsEmpty_expect_EmptyDataSet() {
|
||||||
|
// Assemble
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||||
|
DataSet ds = new DataSet(null, null);
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(ds);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
Assert.assertTrue(ds.isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_builderHasNoPoolOrAssembler_expect_defaultPoolBehavior() {
|
||||||
|
// Arrange
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder().build();
|
||||||
|
DataSet[] observations = new DataSet[5];
|
||||||
|
INDArray[] inputs = new INDArray[5];
|
||||||
|
|
||||||
|
|
||||||
|
// Act
|
||||||
|
for(int i = 0; i < 5; ++i) {
|
||||||
|
inputs[i] = Nd4j.rand(new long[] { 1, 2, 2 });
|
||||||
|
DataSet input = new DataSet(inputs[i], null);
|
||||||
|
sut.preProcess(input);
|
||||||
|
observations[i] = input;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(observations[0].isEmpty());
|
||||||
|
assertTrue(observations[1].isEmpty());
|
||||||
|
assertTrue(observations[2].isEmpty());
|
||||||
|
|
||||||
|
for(int i = 0; i < 4; ++i) {
|
||||||
|
assertEquals(inputs[i].getDouble(new int[] { 0, 0, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
|
||||||
|
assertEquals(inputs[i].getDouble(new int[] { 0, 0, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
|
||||||
|
assertEquals(inputs[i].getDouble(new int[] { 0, 1, 0 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
|
||||||
|
assertEquals(inputs[i].getDouble(new int[] { 0, 1, 1 }), observations[3].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i = 0; i < 4; ++i) {
|
||||||
|
assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 0 }), 0.0001);
|
||||||
|
assertEquals(inputs[i+1].getDouble(new int[] { 0, 0, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 0, 1 }), 0.0001);
|
||||||
|
assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 0 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 0 }), 0.0001);
|
||||||
|
assertEquals(inputs[i+1].getDouble(new int[] { 0, 1, 1 }), observations[4].getFeatures().getDouble(new int[] { 0, i, 1, 1 }), 0.0001);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_builderHasPoolAndAssembler_expect_paramPoolAndAssemblerAreUsed() {
|
||||||
|
// Arrange
|
||||||
|
INDArray input = Nd4j.rand(1, 1);
|
||||||
|
TestObservationPool pool = new TestObservationPool();
|
||||||
|
TestPoolContentAssembler assembler = new TestPoolContentAssembler();
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
|
||||||
|
.observablePool(pool)
|
||||||
|
.poolContentAssembler(assembler)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(new DataSet(input, null));
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(pool.isAtFullCapacityCalled);
|
||||||
|
assertTrue(pool.isGetCalled);
|
||||||
|
assertEquals(input.getDouble(0), pool.observation.getDouble(0), 0.0);
|
||||||
|
assertTrue(assembler.assembleIsCalled);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_pastInputChanges_expect_outputNotChanged() {
|
||||||
|
// Arrange
|
||||||
|
INDArray input = Nd4j.zeros(1, 1);
|
||||||
|
TestObservationPool pool = new TestObservationPool();
|
||||||
|
TestPoolContentAssembler assembler = new TestPoolContentAssembler();
|
||||||
|
PoolingDataSetPreProcessor sut = PoolingDataSetPreProcessor.builder()
|
||||||
|
.observablePool(pool)
|
||||||
|
.poolContentAssembler(assembler)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
sut.preProcess(new DataSet(input, null));
|
||||||
|
input.putScalar(0, 0, 1.0);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0.0, pool.observation.getDouble(0), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestObservationPool implements ObservationPool {
|
||||||
|
|
||||||
|
public INDArray observation;
|
||||||
|
public boolean isGetCalled;
|
||||||
|
public boolean isAtFullCapacityCalled;
|
||||||
|
private boolean isResetCalled;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void add(INDArray observation) {
|
||||||
|
this.observation = observation;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray[] get() {
|
||||||
|
isGetCalled = true;
|
||||||
|
return new INDArray[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isAtFullCapacity() {
|
||||||
|
isAtFullCapacityCalled = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void reset() {
|
||||||
|
isResetCalled = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class TestPoolContentAssembler implements PoolContentAssembler {
|
||||||
|
|
||||||
|
public boolean assembleIsCalled;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray assemble(INDArray[] poolContent) {
|
||||||
|
assembleIsCalled = true;
|
||||||
|
return Nd4j.create(1, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class SkippingDataSetPreProcessorTest {
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_ctorSkipFrameIsZero_expect_IllegalArgumentException() {
|
||||||
|
SkippingDataSetPreProcessor sut = new SkippingDataSetPreProcessor(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_builderSkipFrameIsZero_expect_IllegalArgumentException() {
|
||||||
|
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||||
|
.skipFrame(0)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_skipFrameIs3_expect_Skip2OutOf3() {
|
||||||
|
// Arrange
|
||||||
|
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||||
|
.skipFrame(3)
|
||||||
|
.build();
|
||||||
|
DataSet[] results = new DataSet[4];
|
||||||
|
|
||||||
|
// Act
|
||||||
|
for(int i = 0; i < 4; ++i) {
|
||||||
|
results[i] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||||
|
sut.preProcess(results[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(results[0].isEmpty());
|
||||||
|
assertTrue(results[1].isEmpty());
|
||||||
|
assertTrue(results[2].isEmpty());
|
||||||
|
assertFalse(results[3].isEmpty());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_resetIsCalled_expect_skippingIsReset() {
|
||||||
|
// Arrange
|
||||||
|
SkippingDataSetPreProcessor sut = SkippingDataSetPreProcessor.builder()
|
||||||
|
.skipFrame(3)
|
||||||
|
.build();
|
||||||
|
DataSet[] results = new DataSet[4];
|
||||||
|
|
||||||
|
// Act
|
||||||
|
results[0] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||||
|
results[1] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||||
|
results[2] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||||
|
results[3] = new DataSet(Nd4j.create(new double[] { 123.0 }), null);
|
||||||
|
|
||||||
|
sut.preProcess(results[0]);
|
||||||
|
sut.preProcess(results[1]);
|
||||||
|
sut.reset();
|
||||||
|
sut.preProcess(results[2]);
|
||||||
|
sut.preProcess(results[3]);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(results[0].isEmpty());
|
||||||
|
assertTrue(results[1].isEmpty());
|
||||||
|
assertFalse(results[2].isEmpty());
|
||||||
|
assertTrue(results[3].isEmpty());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class ChannelStackPoolContentAssemblerTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_assemble_expect_poolContentStackedOnChannel() {
|
||||||
|
// Assemble
|
||||||
|
ChannelStackPoolContentAssembler sut = new ChannelStackPoolContentAssembler();
|
||||||
|
INDArray[] poolContent = new INDArray[] {
|
||||||
|
Nd4j.rand(2, 2),
|
||||||
|
Nd4j.rand(2, 2),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray result = sut.assemble(poolContent);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3, result.shape().length);
|
||||||
|
assertEquals(2, result.shape()[0]);
|
||||||
|
assertEquals(2, result.shape()[1]);
|
||||||
|
assertEquals(2, result.shape()[2]);
|
||||||
|
|
||||||
|
assertEquals(poolContent[0].getDouble(0, 0), result.getDouble(0, 0, 0), 0.0001);
|
||||||
|
assertEquals(poolContent[0].getDouble(0, 1), result.getDouble(0, 0, 1), 0.0001);
|
||||||
|
assertEquals(poolContent[0].getDouble(1, 0), result.getDouble(0, 1, 0), 0.0001);
|
||||||
|
assertEquals(poolContent[0].getDouble(1, 1), result.getDouble(0, 1, 1), 0.0001);
|
||||||
|
|
||||||
|
assertEquals(poolContent[1].getDouble(0, 0), result.getDouble(1, 0, 0), 0.0001);
|
||||||
|
assertEquals(poolContent[1].getDouble(0, 1), result.getDouble(1, 0, 1), 0.0001);
|
||||||
|
assertEquals(poolContent[1].getDouble(1, 0), result.getDouble(1, 1, 0), 0.0001);
|
||||||
|
assertEquals(poolContent[1].getDouble(1, 1), result.getDouble(1, 1, 1), 0.0001);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
package org.deeplearning4j.rl4j.observation.preprocessor.pooling;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
public class CircularFifoObservationPoolTest {
|
||||||
|
|
||||||
|
@Test(expected = IllegalArgumentException.class)
|
||||||
|
public void when_poolSizeZeroOrLess_expect_IllegalArgumentException() {
|
||||||
|
CircularFifoObservationPool sut = new CircularFifoObservationPool(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_poolIsEmpty_expect_NotReady() {
|
||||||
|
// Assemble
|
||||||
|
CircularFifoObservationPool sut = new CircularFifoObservationPool();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isReady = sut.isAtFullCapacity();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(isReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_notEnoughElementsInPool_expect_notReady() {
|
||||||
|
// Assemble
|
||||||
|
CircularFifoObservationPool sut = new CircularFifoObservationPool();
|
||||||
|
sut.add(Nd4j.create(new double[] { 123.0 }));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isReady = sut.isAtFullCapacity();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertFalse(isReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_enoughElementsInPool_expect_ready() {
|
||||||
|
// Assemble
|
||||||
|
CircularFifoObservationPool sut = CircularFifoObservationPool.builder()
|
||||||
|
.poolSize(2)
|
||||||
|
.build();
|
||||||
|
sut.add(Nd4j.createFromArray(123.0));
|
||||||
|
sut.add(Nd4j.createFromArray(123.0));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
boolean isReady = sut.isAtFullCapacity();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertTrue(isReady);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_addMoreThanSize_expect_getReturnOnlyLastElements() {
|
||||||
|
// Assemble
|
||||||
|
CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
|
||||||
|
sut.add(Nd4j.createFromArray(0.0));
|
||||||
|
sut.add(Nd4j.createFromArray(1.0));
|
||||||
|
sut.add(Nd4j.createFromArray(2.0));
|
||||||
|
sut.add(Nd4j.createFromArray(3.0));
|
||||||
|
sut.add(Nd4j.createFromArray(4.0));
|
||||||
|
sut.add(Nd4j.createFromArray(5.0));
|
||||||
|
sut.add(Nd4j.createFromArray(6.0));
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray[] result = sut.get();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(3.0, result[0].getDouble(0), 0.0);
|
||||||
|
assertEquals(4.0, result[1].getDouble(0), 0.0);
|
||||||
|
assertEquals(5.0, result[2].getDouble(0), 0.0);
|
||||||
|
assertEquals(6.0, result[3].getDouble(0), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void when_resetIsCalled_expect_poolContentFlushed() {
|
||||||
|
// Assemble
|
||||||
|
CircularFifoObservationPool sut = CircularFifoObservationPool.builder().build();
|
||||||
|
sut.add(Nd4j.createFromArray(0.0));
|
||||||
|
sut.add(Nd4j.createFromArray(1.0));
|
||||||
|
sut.add(Nd4j.createFromArray(2.0));
|
||||||
|
sut.add(Nd4j.createFromArray(3.0));
|
||||||
|
sut.add(Nd4j.createFromArray(4.0));
|
||||||
|
sut.add(Nd4j.createFromArray(5.0));
|
||||||
|
sut.add(Nd4j.createFromArray(6.0));
|
||||||
|
sut.reset();
|
||||||
|
|
||||||
|
// Act
|
||||||
|
INDArray[] result = sut.get();
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assertEquals(0, result.length);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue