Shugeo resize area fix4 (#229)
* Fixed a couple of issues with resize_area op. Signed-off-by: shugeo <sgazeos@gmail.com> * Added additional test for alternate params for resize_area testing. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
11cb561045
commit
f0c684020f
|
@ -80,8 +80,8 @@ namespace nd4j {
|
||||||
"resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
"resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
REQUIRE_TRUE(block.numI() <= 1, 0,
|
REQUIRE_TRUE(block.numI() <= 1, 0,
|
||||||
"resize_area: Resize params already given by the second param. Int params are expensive.");
|
"resize_area: Resize params already given by the second param. Int params are expensive.");
|
||||||
width = newImageSize->e<int>(0);
|
width = newImageSize->e<int>(1);
|
||||||
height = newImageSize->e<int>(1);
|
height = newImageSize->e<int>(0);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor.");
|
REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor.");
|
||||||
|
@ -95,13 +95,13 @@ namespace nd4j {
|
||||||
outputShape[0] = inRank;
|
outputShape[0] = inRank;
|
||||||
if (inRank == 4) {
|
if (inRank == 4) {
|
||||||
outputShape[1] = in[1];
|
outputShape[1] = in[1];
|
||||||
outputShape[2] = width;
|
outputShape[2] = height;
|
||||||
outputShape[3] = height;
|
outputShape[3] = width;
|
||||||
outputShape[4] = in[4];
|
outputShape[4] = in[4];
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
outputShape[1] = width;
|
outputShape[1] = height;
|
||||||
outputShape[2] = height;
|
outputShape[2] = width;
|
||||||
outputShape[3] = in[3];
|
outputShape[3] = in[3];
|
||||||
}
|
}
|
||||||
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
|
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
|
||||||
|
|
|
@ -1116,7 +1116,7 @@ namespace helpers {
|
||||||
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
|
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
|
||||||
ScaleCache<T>* cachePool;
|
ScaleCache<T>* cachePool;
|
||||||
err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
|
err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
|
||||||
resizeAreaKernel<T><<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
|
resizeAreaKernel<T><<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
|
||||||
output->specialShapeInfo(), cachePool);
|
output->specialShapeInfo(), cachePool);
|
||||||
err = cudaStreamSynchronize(*stream);
|
err = cudaStreamSynchronize(*stream);
|
||||||
err = cudaFree(cachePool);
|
err = cudaFree(cachePool);
|
||||||
|
|
|
@ -1520,6 +1520,65 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
|
||||||
|
});
|
||||||
|
auto size = NDArrayFactory::create<int>({8, 7});
|
||||||
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
|
||||||
|
1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f,
|
||||||
|
4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f,
|
||||||
|
9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f,
|
||||||
|
13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f,
|
||||||
|
18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f ,
|
||||||
|
20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f ,
|
||||||
|
21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f ,
|
||||||
|
25.f
|
||||||
|
}); //input.linspace(1);
|
||||||
|
// auto size = NDArrayFactory::create<int>({6, 6});
|
||||||
|
nd4j::ops::resize_area op;
|
||||||
|
auto results = op.evaluate({&input, &size}, {}, {false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
// result->printBuffer("Area Resized to 8x7");
|
||||||
|
// expected.printBuffer("Area Expect for 8x7");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 5, 5, 1}, {
|
||||||
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25
|
||||||
|
});
|
||||||
|
//auto size = NDArrayFactory::create<int>({8, 7});
|
||||||
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 8, 7, 1}, {
|
||||||
|
1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f,
|
||||||
|
4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f,
|
||||||
|
9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f,
|
||||||
|
13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f,
|
||||||
|
18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f ,
|
||||||
|
20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f ,
|
||||||
|
22.999998f , 23.800001f , 24.399984f , 25.f
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::resize_area op;
|
||||||
|
auto results = op.evaluate({&input}, {}, {8, 7}, {false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
// result->printBuffer("Area Resized to 8x7");
|
||||||
|
// expected.printBuffer("Area Expect for 8x7");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue