Shugeo resize area fix4 (#465)

* Restore resize_area test suite.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed resize_area kernel for cuda platform to avoid range violation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed resizeAreaKernel launch configuration.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added error handling for CUDA runtime calls in the resize area cuda implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2020-05-26 14:13:48 +03:00 committed by GitHub
parent ecdee6369d
commit a18417193d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 5 deletions

View File

@ -1066,7 +1066,7 @@ namespace helpers {
const Nd4jLong yStart = math::nd4j_floor<float, Nd4jLong>(inY);
const Nd4jLong yEnd = math::nd4j_ceil<float, Nd4jLong>(inY1);
auto scalesDim = yEnd - yStart;
auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * sizeof(ScaleCache<T>);
auto yScaleCache = cachePool + (batch * pSt->outHeight + y) * pSt->outWidth;
//auto startPtr = sharedPtr + y * scalesDim * sizeof(float);
//float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast<float*>(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim];
@ -1113,14 +1113,34 @@ namespace helpers {
auto outputPtr = reinterpret_cast<float*>(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template <typename X, typename Z> declaration
ImageResizerState* pSt;
auto err = cudaMalloc(&pSt, sizeof(ImageResizerState));
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for ImageResizerState", err);
}
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: Cannot copy to device memory", err);
}
ScaleCache<T>* cachePool;
err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
resizeAreaKernel<T><<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr,
auto cachePoolSize = sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight;
err = cudaMalloc(&cachePool, cachePoolSize);
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for cache", err);
}
resizeAreaKernel<T><<<128, 128, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr,
output->specialShapeInfo(), cachePool);
err = cudaStreamSynchronize(*stream);
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: An error occured with kernel running", err);
}
err = cudaFree(cachePool);
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for cache", err);
}
err = cudaFree(pSt);
if (err != 0) {
throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for ImageResizeState", err);
}
}
// ------------------------------------------------------------------------------------------------------------------ //
template <typename T>
@ -1134,11 +1154,20 @@ namespace helpers {
CachedInterpolation* xCached;
//(st.outWidth);
auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth);
if (err != 0) {
throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot allocate memory for cached interpolations", err);
}
NDArray::prepareSpecialUse({output}, {image});
fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale);
resizeArea<T>(stream, st, xCached, image, output);
err = cudaStreamSynchronize(*stream);
if (err != 0) {
throw cuda_exception::build("helpers::resizeAreaFunctor_: Error occured when kernel was running", err);
}
err = cudaFree(xCached);
if (err != 0) {
throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot deallocate memory for cached interpolations", err);
}
NDArray::registerSpecialUse({output}, {image});
}

View File

@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
ASSERT_TRUE(testData.equalsTo(result));
}
/*
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
@ -1532,7 +1532,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
}
*/
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {