Shugeo resize area fix4 (#229)

* Fixed a couple of issues with resize_area op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added additional test for alternate params for resize_area testing.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-02-12 18:02:42 +02:00 committed by GitHub
parent 11cb561045
commit f0c684020f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 7 deletions

View File

@ -80,8 +80,8 @@ namespace nd4j {
"resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0,
"resize_area: Resize params already given by the second param. Int params are expensive.");
width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1);
width = newImageSize->e<int>(1);
height = newImageSize->e<int>(0);
}
else {
REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor.");
@ -95,13 +95,13 @@ namespace nd4j {
outputShape[0] = inRank;
if (inRank == 4) {
outputShape[1] = in[1];
outputShape[2] = width;
outputShape[3] = height;
outputShape[2] = height;
outputShape[3] = width;
outputShape[4] = in[4];
}
else {
outputShape[1] = width;
outputShape[2] = height;
outputShape[1] = height;
outputShape[2] = width;
outputShape[3] = in[3];
}
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));

View File

@ -1116,7 +1116,7 @@ namespace helpers {
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
ScaleCache<T>* cachePool;
err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
resizeAreaKernel<T><<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
resizeAreaKernel<T><<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
output->specialShapeInfo(), cachePool);
err = cudaStreamSynchronize(*stream);
err = cudaFree(cachePool);

View File

@ -1520,6 +1520,65 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) {
NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
});
auto size = NDArrayFactory::create<int>({8, 7});
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f,
4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f,
9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f,
13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f,
18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f ,
20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f ,
21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f ,
25.f
}); //input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input, &size}, {}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 8x7");
// expected.printBuffer("Area Expect for 8x7");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
});
//auto size = NDArrayFactory::create<int>({8, 7});
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f,
4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f,
9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f,
13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f,
18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f ,
20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f ,
22.999998f , 23.800001f , 24.399984f , 25.f
});
nd4j::ops::resize_area op;
auto results = op.evaluate({&input}, {}, {8, 7}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 8x7");
// expected.printBuffer("Area Expect for 8x7");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {