diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp index b0f637c45..984672ad2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_area.cpp @@ -80,8 +80,8 @@ namespace nd4j { "resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_area: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); + width = newImageSize->e(1); + height = newImageSize->e(0); } else { REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor."); @@ -95,13 +95,13 @@ namespace nd4j { outputShape[0] = inRank; if (inRank == 4) { outputShape[1] = in[1]; - outputShape[2] = width; - outputShape[3] = height; + outputShape[2] = height; + outputShape[3] = width; outputShape[4] = in[4]; } else { - outputShape[1] = width; - outputShape[2] = height; + outputShape[1] = height; + outputShape[2] = width; outputShape[3] = in[3]; } ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 94df35964..c028daff3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -1116,7 +1116,7 @@ namespace helpers { err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); ScaleCache* cachePool; err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight); - resizeAreaKernel<<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr, + resizeAreaKernel<<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr, output->specialShapeInfo(), cachePool); err = cudaStreamSynchronize(*stream); err = cudaFree(cachePool); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 465703768..aeecaccef 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1520,6 +1520,65 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { delete results; } +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , + 21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f , + 25.f + }); //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + nd4j::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("Area Resized to 8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + //auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f , + 22.999998f , 23.800001f , 24.399984f , 25.f + }); + + nd4j::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {8, 7}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); +// result->printBuffer("Area Resized to 8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {