Shugeo resize area fix4 (#465)
* Restore resize_area test suite. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed resize_area kernel for cuda platform to avoid range violation. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed resizeAreaKernel start. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed potential error handling with resize area cuda implementation. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
ecdee6369d
commit
a18417193d
|
@ -1066,7 +1066,7 @@ namespace helpers {
|
|||
const Nd4jLong yStart = math::nd4j_floor<float, Nd4jLong>(inY);
|
||||
const Nd4jLong yEnd = math::nd4j_ceil<float, Nd4jLong>(inY1);
|
||||
auto scalesDim = yEnd - yStart;
|
||||
auto yScaleCache = cachePool + (batch * pSt->outWidth + y) * scalesDim * sizeof(ScaleCache<T>);
|
||||
auto yScaleCache = cachePool + (batch * pSt->outHeight + y) * pSt->outWidth;
|
||||
|
||||
//auto startPtr = sharedPtr + y * scalesDim * sizeof(float);
|
||||
//float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast<float*>(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim];
|
||||
|
@ -1113,14 +1113,34 @@ namespace helpers {
|
|||
auto outputPtr = reinterpret_cast<float*>(output->specialBuffer()); // output is always float. TO DO: provide another float types also with template <typename X, typename Z> declaration
|
||||
ImageResizerState* pSt;
|
||||
auto err = cudaMalloc(&pSt, sizeof(ImageResizerState));
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for ImageResizerState", err);
|
||||
}
|
||||
|
||||
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: Cannot copy to device memory", err);
|
||||
}
|
||||
ScaleCache<T>* cachePool;
|
||||
err = cudaMalloc(&cachePool, sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight);
|
||||
resizeAreaKernel<T><<<128, 2, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr,
|
||||
auto cachePoolSize = sizeof(ScaleCache<T>) * st.batchSize * st.outWidth * st.outHeight;
|
||||
err = cudaMalloc(&cachePool, cachePoolSize);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: Cannot allocate memory for cache", err);
|
||||
}
|
||||
resizeAreaKernel<T><<<128, 128, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->specialShapeInfo(), outputPtr,
|
||||
output->specialShapeInfo(), cachePool);
|
||||
err = cudaStreamSynchronize(*stream);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: An error occured with kernel running", err);
|
||||
}
|
||||
err = cudaFree(cachePool);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for cache", err);
|
||||
}
|
||||
err = cudaFree(pSt);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeArea: Cannot deallocate memory for ImageResizeState", err);
|
||||
}
|
||||
}
|
||||
// ------------------------------------------------------------------------------------------------------------------ //
|
||||
template <typename T>
|
||||
|
@ -1134,11 +1154,20 @@ namespace helpers {
|
|||
CachedInterpolation* xCached;
|
||||
//(st.outWidth);
|
||||
auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot allocate memory for cached interpolations", err);
|
||||
}
|
||||
NDArray::prepareSpecialUse({output}, {image});
|
||||
fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale);
|
||||
resizeArea<T>(stream, st, xCached, image, output);
|
||||
err = cudaStreamSynchronize(*stream);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeAreaFunctor_: Error occured when kernel was running", err);
|
||||
}
|
||||
err = cudaFree(xCached);
|
||||
if (err != 0) {
|
||||
throw cuda_exception::build("helpers::resizeAreaFunctor_: Cannot deallocate memory for cached interpolations", err);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {image});
|
||||
}
|
||||
|
||||
|
|
|
@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
|
|||
ASSERT_TRUE(testData.equalsTo(result));
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
|
||||
|
@ -1532,7 +1532,7 @@ TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) {
|
|||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
||||
|
|
Loading…
Reference in New Issue