diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index d483f87b3..180c8ad0e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -1066,7 +1066,7 @@ namespace helpers { const Nd4jLong yStart = math::nd4j_floor(inY); const Nd4jLong yEnd = math::nd4j_ceil(inY1); auto scalesDim = yEnd - yStart; - auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * sizeof(ScaleCache); + auto yScaleCache = cachePool + (batch * pSt->outHeight + y) * pSt->outWidth; //auto startPtr = sharedPtr + y * scalesDim * sizeof(float); //float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim]; @@ -1113,14 +1113,34 @@ namespace helpers { auto outputPtr = reinterpret_cast(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template declaration ImageResizerState* pSt; auto err = cudaMalloc(&pSt, sizeof(ImageResizerState)); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for ImageResizerState", err); + } + err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: Cannot copy to device memory", err); + } ScaleCache* cachePool; - err = cudaMalloc(&cachePool, sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight); - resizeAreaKernel<<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, + auto cachePoolSize = sizeof(ScaleCache) * st.batchSize * st.outWidth * st.outHeight; + err = cudaMalloc(&cachePool, cachePoolSize); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for cache", err); + } + resizeAreaKernel<<<128, 128, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr, output->specialShapeInfo(), cachePool); err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: An error occured with kernel running", err); + } err = cudaFree(cachePool); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for cache", err); + } err = cudaFree(pSt); + if (err != 0) { + throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for ImageResizeState", err); + } } // ------------------------------------------------------------------------------------------------------------------ // template @@ -1134,11 +1154,20 @@ namespace helpers { CachedInterpolation* xCached; //(st.outWidth); auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth); + if (err != 0) { + throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot allocate memory for cached interpolations", err); + } NDArray::prepareSpecialUse({output}, {image}); fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale); resizeArea(stream, st, xCached, image, output); err = cudaStreamSynchronize(*stream); + if (err != 0) { + throw cuda_exception::build("helpers::resizeAreaFunctor_: Error occured when kernel was running", err); + } err = cudaFree(xCached); + if (err != 0) { + throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot deallocate memory for cached interpolations", err); + } NDArray::registerSpecialUse({output}, {image}); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index e4391c688..23c40ebae 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { ASSERT_TRUE(testData.equalsTo(result)); } -/* + TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); @@ -1532,7 +1532,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); } - */ + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {