Shugeo resize area (#162)

* Added implementation for resize_area op. Initial commit.

* Added implementation of resize_area op. Initial revision.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected resizeArea functor call.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implementation of resize_area. Cpu platform helpers.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implementation of resize_area helpers. First partial revision.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added a set of tests for resize_area op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cuda implementation for resize_area. Initial approach.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Adding multithreading for resize_area algorithm.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Cuda implementation of resize_area helpers. Shared memory approach.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored resizeAreaKernel with cuda implementation.

* Eliminated compilation errors.

* ResizeArea helpers for cuda platform. The first working revision.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added test for batched resize_area op.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Implementation of resize_area for cuda platform and tests.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed multithreading with resize_area op helper.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected copyright marks in sources.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected copyright mark for resize_area op implementation.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected copyright mark for parity ops header.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Corrected typos in strings and messages for image resize ops.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored resize_area helpers and multithreading.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added ResizeArea wrapper

* Added test with align_corners and fixed shape processing with only int args given for output size.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Added test

* TF mapping for ResizeArea

* Fixed implementation issues with resize_area op for both platforms.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored image resizer struct to use flexible types for ints and floats.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Improved multithreading with resizeAreaKernel launch.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Use asynchronous memory copying with cuda platform image resize allocations.

Signed-off-by: shugeo <sgazeos@gmail.com>

Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
shugeo 2020-01-22 09:46:33 +02:00 committed by raver119
parent 7783012f39
commit e50b285c2c
13 changed files with 1061 additions and 24 deletions

View File

@@ -0,0 +1,122 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_resize_area)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_resize.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(resize_area, 1, 1, false, 0, -2) {
auto image = INPUT_VARIABLE(0);
int width;
int height;
if (block.width() == 2) {
auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width)
REQUIRE_TRUE(size->rankOf() == 1 && size->lengthOf() == 2, 0, "resize_area: Resize params is a pair of values, not %i.", size->lengthOf());
size->syncToHost();
width = size->e<int>(1);
height = size->e<int>(0);
}
else {
REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params already given by the second param. Int params are expensive.");
width = INT_ARG(1);
height = INT_ARG(0);
}
auto output = OUTPUT_VARIABLE(0);
if (output->isEmpty()) return Status::OK();
auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_area: Source tensor should have rank 3 or 4, but %i given.", inRank);
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_area: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
REQUIRE_TRUE(width > 0 , 0, "resize_area: picture width should be positive 32 bit integer, but %i given", width);
REQUIRE_TRUE(height > 0 , 0, "resize_area: picture height should be positive 32 bit integer, but %i given", height);
REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_area: Only non-empty images are allowed for processing.");
auto alignCorners = false;
if (block.numB() > 0) {
alignCorners = B_ARG(0);
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}
DECLARE_SHAPE_FN(resize_area) {
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
auto inRank = shape::rank(in);
int width;
int height;
if (block.width() == 2) {
auto newImageSize = INPUT_VARIABLE(1);
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0,
"resize_area: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0,
"resize_area: Resize params already given by the second param. Int params are expensive.");
width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1);
}
else {
REQUIRE_TRUE(block.numI() == 2, 0, "resize_area: Resize params ommited as pair ints nor int tensor.");
width = INT_ARG(1);
height = INT_ARG(0);
}
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {
outputShape[1] = in[1];
outputShape[2] = width;
outputShape[3] = height;
outputShape[4] = in[4];
}
else {
outputShape[1] = width;
outputShape[2] = height;
outputShape[3] = in[3];
}
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
shapeList->push_back(CONSTANT(outputShape));
return shapeList;
}
DECLARE_TYPES(resize_area) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedInputTypes(1, DataType::INT32)
->setAllowedOutputTypes({DataType::FLOAT32});
}
}
}
#endif

View File

@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -76,8 +76,8 @@ namespace nd4j {
int width;
int height;
auto newImageSize = INPUT_VARIABLE(1);
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive.");
width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1);

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,7 +20,7 @@
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_resize_bilinear)
#if NOT_EXCLUDED(OP_resize_nearest_neighbor)
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
@@ -54,7 +55,7 @@ namespace nd4j {
if (block.numB() > 1)
halfPixelCenter = B_ARG(1);
REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbor: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
@@ -73,7 +74,7 @@ namespace nd4j {
auto inRank = shape::rank(in);
Nd4jLong* outputShape;
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 4D "
"tensor, but input has rank %i",
inRank);

View File

@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -1722,6 +1722,27 @@ namespace nd4j {
DECLARE_CUSTOM_OP(resize_bicubic, 1, 1, false, 0, -2);
#endif
/**
* This op makes an area-interpolated resize (as the OpenCV INTER_AREA algorithm) of the given tensor
*
* input array:
* 0 - images - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
* 1 - size - 1D-Tensor with 2 values (newWidth, newHeight) (if missing, a pair of integer args should be provided).
*
* int args: - provided only when the size tensor is missing
* 0 - new height
* 1 - new width
* boolean args:
* 0 - align_corners - optional (default is false)
*
* output array:
* the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
*
*/
#if NOT_EXCLUDED(OP_resize_area)
DECLARE_CUSTOM_OP(resize_area, 1, 1, false, 0, -2);
#endif
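For reference, here is a minimal usage sketch of the two calling conventions documented above (size tensor vs. integer args, plus the optional align_corners boolean arg). This is an editor's illustration rather than part of the diff; it simply mirrors the C++ tests added later in this commit.

#include <ops/declarable/CustomOperations.h>

// NHWC image tensor {batch, height, width, channels}, filled with 1..9.
auto image = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
image.linspace(1);

nd4j::ops::resize_area op;

// Convention 1: pass the new size as a rank-1 INT32 tensor with two values.
auto size = NDArrayFactory::create<int>({6, 6});
auto resultsWithSizeTensor = op.execute({&image, &size}, {}, {});

// Convention 2: pass the new size as two int args and align_corners as a bool arg.
auto resultsWithIntArgs = op.execute({&image}, {}, {6, 6}, {true});

delete resultsWithSizeTensor;
delete resultsWithIntArgs;

Whichever convention is used, the output is always FLOAT32, as the shape function and the DECLARE_TYPES block above enforce.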
/**
* This op makes an interpolated resize of the given tensor with the given algorithm.
* Supported algorithms are bilinear, bicubic, nearest_neighbor.

View File

@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -35,6 +35,8 @@ limitations under the License.
#include <ops/declarable/helpers/image_resize.h>
#include <execution/Threads.h>
#include <ops/declarable/headers/parity_ops.h>
#include "../cross.h"
namespace nd4j {
namespace ops {
@@ -55,8 +57,9 @@ namespace helpers {
: inSize / static_cast<float>(outSize);
}
struct ImageResizerState {
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
template <typename I, typename F>
struct ImageResizerStateCommon {
explicit ImageResizerStateCommon(bool alignCorners, bool halfPixelCenters)
: _alignCorners(alignCorners),
_halfPixelCenters(halfPixelCenters) {}
@@ -94,14 +97,14 @@ namespace helpers {
return validateAndCalculateOutputSize(input, width, height);
}
Nd4jLong batchSize;
Nd4jLong outHeight;
Nd4jLong outWidth;
Nd4jLong inHeight;
Nd4jLong inWidth;
Nd4jLong channels;
float heightScale;
float widthScale;
I batchSize;
I outHeight;
I outWidth;
I inHeight;
I inWidth;
I channels;
F heightScale;
F widthScale;
NDArray* output = nullptr;
private:
@@ -109,6 +112,8 @@ namespace helpers {
bool _halfPixelCenters;
};
typedef ImageResizerStateCommon<Nd4jLong, float> ImageResizerState;
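The refactor above turns the resizer state into a template over its integer and floating-point member types, with the shipped alias preserving the previous Nd4jLong/float behaviour. A brief illustrative sketch (editor's note, not part of the diff; the double-precision alias is hypothetical and unused here) of how the state is constructed and filled, exactly as the resize functors below do:

// Shipped alias: ImageResizerStateCommon<Nd4jLong, float>. A hypothetical
// double-precision variant would read ImageResizerStateCommon<Nd4jLong, double>.
ImageResizerState st(/*alignCorners=*/false, /*halfPixelCenters=*/false);
if (st.validateAndCalculateOutputSize(image, width, height) == Status::OK()) {
    // st.outHeight, st.outWidth, st.heightScale and st.widthScale are now populated,
    // just as resizeAreaFunctor_ relies on further down.
}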
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
// floating point coordinates of the top,left pixel is 0.5,0.5.
struct HalfPixelScaler {
@@ -255,7 +260,7 @@ namespace helpers {
// Handle no-op resizes efficiently.
if (outHeight == inHeight && outWidth == inWidth) {
output->assign(images);
return ND4J_STATUS_OK;
return Status::OK();
}
std::vector<BilinearInterpolationData> ys(outHeight + 1);
@@ -283,7 +288,7 @@ namespace helpers {
samediff::Threads::parallel_for(func, 0, xsSize);
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
return ND4J_STATUS_OK;
return Status::OK();
}
template <class Scaler, typename T>
@@ -353,6 +358,7 @@ namespace helpers {
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
return Status::OK();
}
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
@@ -883,6 +889,206 @@ namespace helpers {
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
}
// ------------------------------------------------------------------------------------------------------------------ //
struct CachedInterpolation {
Nd4jLong start;
Nd4jLong end;
float startScale;
float endMinusOneScale;
bool needsBounding;
};
template <typename T>
struct ScaleCache {
float yScale;
T const* yPtr;
};
// Computes the weighted sum of the input columns described by xCache, taken across
// the row pointers and scales held in yPtrs, for the three-channel case.
//
// Bounding along x is applied only when xCache.needsBounding is set, so the check
// is hoisted out of the inner loops via the boundIfNeeded lambda.
template <typename T>
static void computePatchSumOf3Channels(float scale,
ImageResizerState const& st,
std::vector<ScaleCache<T>> const& yPtrs,
CachedInterpolation const& xCache,
float* outputPtr) {
bool const needsXBounding = xCache.needsBounding;
auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong {
return (needsXBounding ? bound(x, y) : (x));
};
float sum_0 = 0;
float sum_1 = 0;
float sum_2 = 0;
for (int i = 0; i < yPtrs.size(); ++i) {
const T* ptr = yPtrs[i].yPtr;
float scaleX = xCache.startScale;
Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth);
float sum_y_0 = static_cast<float>(ptr[offset + 0]) * scaleX;
float sum_y_1 = static_cast<float>(ptr[offset + 1]) * scaleX;
float sum_y_2 = static_cast<float>(ptr[offset + 2]) * scaleX;
if (xCache.start + 1 != xCache.end) {
for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) {
Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth);
sum_y_0 += static_cast<float>(ptr[offset + 0]);
sum_y_1 += static_cast<float>(ptr[offset + 1]);
sum_y_2 += static_cast<float>(ptr[offset + 2]);
}
scaleX = xCache.endMinusOneScale;
offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth);
sum_y_0 += static_cast<float>(ptr[offset + 0]) * scaleX;
sum_y_1 += static_cast<float>(ptr[offset + 1]) * scaleX;
sum_y_2 += static_cast<float>(ptr[offset + 2]) * scaleX;
}
sum_0 += sum_y_0 * yPtrs[i].yScale;
sum_1 += sum_y_1 * yPtrs[i].yScale;
sum_2 += sum_y_2 * yPtrs[i].yScale;
}
outputPtr[0] = sum_0 * scale;
outputPtr[1] = sum_1 * scale;
outputPtr[2] = sum_2 * scale;
}
// Computes the weighted sum of the input columns described by xCache, taken across
// the row pointers and scales held in yPtrs, separately for every channel c.
//
// Bounding along x is applied only when xCache.needsBounding is set, so the check
// is hoisted out of the inner loops via the boundIfNeeded lambda.
template <typename T>
static void computePatchSum(float scale, const ImageResizerState& st,
const std::vector<ScaleCache<T>>& yPtrs,
const CachedInterpolation& xCache,
float* outputPtr) {
bool const needsXBounding = xCache.needsBounding;
auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong {
return (needsXBounding ? bound(x, y) : (x));
};
const auto numChannels = st.channels;
for (Nd4jLong c = 0; c < numChannels; ++c) {
float sum = 0;
for (int i = 0; i < yPtrs.size(); ++i) {
T const* ptr = yPtrs[i].yPtr;
float scaleX = xCache.startScale;
float sumY = static_cast<float>(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX;
if (xCache.start + 1 != xCache.end) {
for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) {
sumY += static_cast<float>(
ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]);
}
scaleX = xCache.endMinusOneScale;
sumY += static_cast<float>(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX;
}
sum += sumY * yPtrs[i].yScale;
}
outputPtr[c] = sum * scale;
}
}
template <typename T>
static void resizeArea(ImageResizerState const& st, std::vector<CachedInterpolation> const& caches, NDArray const* input, NDArray* output) {
T const* inputPtr = input->bufferAsT<T>();
float scale = 1.f / (st.heightScale * st.widthScale);
auto outputPtr = output->bufferAsT<float>(); // output is always float. TODO: support other float output types via a template <typename X, typename Z> declaration.
auto batchProcess = PRAGMA_THREADS_FOR {
for (auto batch = start; batch < stop; batch += increment) {
for (auto y = 0; y < st.outHeight; ++y) {
const float inY = y * st.heightScale;
const float inY1 = (y + 1) * st.heightScale;
// The start and end height indices of all the cells that could
// contribute to the target cell.
const Nd4jLong yStart = math::nd4j_floor<float, Nd4jLong>(inY);
const Nd4jLong yEnd = math::nd4j_ceil<float, Nd4jLong>(inY1);
std::vector<ScaleCache<T>> yCaches;
auto cacheLen = yEnd - yStart;
if (cacheLen) {
yCaches.resize(cacheLen);
};
for (auto i = yStart, k = 0LL; i < yEnd; ++i, ++k) {
ScaleCache<T> scaleCache;
if (i < inY) {
scaleCache.yScale = (i + 1 > inY1 ? st.heightScale : i + 1 - inY);
} else {
scaleCache.yScale = (i + 1 > inY1 ? inY1 - i : 1.0);
}
scaleCache.yPtr = inputPtr + (batch * st.inHeight * st.inWidth * st.channels +
bound(i, st.inHeight) * st.inWidth * st.channels);
yCaches[k] = scaleCache;
}
float* output = outputPtr + (batch * st.outHeight + y) * st.channels * st.outWidth;
if (st.channels == 3) {
for (Nd4jLong x = 0; x < st.outWidth; ++x) {
const CachedInterpolation &xCache = caches[x];
computePatchSumOf3Channels<T>(scale, st, yCaches, xCache, output);
output += st.channels;
}
} else {
for (Nd4jLong x = 0; x < st.outWidth; ++x) {
const CachedInterpolation &xCache = caches[x];
computePatchSum<T>(scale, st, yCaches, xCache, output);
output += st.channels;
}
}
}
}
};
samediff::Threads::parallel_tad(batchProcess, 0, st.batchSize, 1);
}
template <typename X>
int resizeAreaFunctor_(nd4j::LaunchContext* context, NDArray const* image, int const width, int const height,
bool const alignCorners, NDArray* output) {
ImageResizerState st(alignCorners, false); // Create resize info
auto res = st.validateAndCalculateOutputSize(image, width, height);
if (Status::OK() == res) {
std::vector<CachedInterpolation> xCached(st.outWidth);
auto cachingProcedure = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; x += increment) {
auto &xCache = xCached[x];
const float inX = x * st.widthScale;
const float inX1 = (x + 1) * st.widthScale;
Nd4jLong v = math::nd4j_floor<float, Nd4jLong>(inX);
xCache.start = v;
xCache.startScale =
v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v
: 1.f);
v = math::nd4j_ceil<float, Nd4jLong>(inX1);
xCache.end = v--;
xCache.endMinusOneScale =
v < inX ? (v + 1 > inX1 ? st.widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v
: 1.f);
xCache.needsBounding = bound(xCache.start, st.inWidth) != xCache.start ||
bound(xCache.end - 1, st.inWidth) != (xCache.end - 1);
}
};
samediff::Threads::parallel_for(cachingProcedure, 0, xCached.size(), 1);
resizeArea<X>(st, xCached, image, output);
}
return res;
}
int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES);
}
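To make the x-cache arithmetic in resizeAreaFunctor_ above easier to follow, here is a small standalone sketch (editor's illustration; cacheForColumn is a hypothetical helper, not part of this commit) that restates how the start/end weights for a single output column are derived, with a worked number for a width-3 to width-2 downscale:

#include <cmath>

struct ColumnCache { long start; long end; float startScale; float endMinusOneScale; };

// Restates the body of cachingProcedure for one output column x.
static ColumnCache cacheForColumn(long x, float widthScale) {
    ColumnCache c;
    const float inX  = x * widthScale;        // left edge of the output cell in input coordinates
    const float inX1 = (x + 1) * widthScale;  // right edge of the output cell
    long v = static_cast<long>(std::floor(inX));
    c.start = v;
    // Fraction of the first contributing input column covered by the output cell.
    c.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX)
                           : (v + 1 > inX1 ? inX1 - v : 1.f);
    v = static_cast<long>(std::ceil(inX1));
    c.end = v--;
    // Fraction of the last contributing input column covered by the output cell.
    c.endMinusOneScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX)
                                 : (v + 1 > inX1 ? inX1 - v : 1.f);
    return c;
}

// Example: shrinking width 3 -> 2 gives widthScale = 1.5, so output column 0 spans
// input columns [0, 1.5): full weight for column 0, half weight for column 1. The
// patch sums are later normalised by scale = 1.f / (heightScale * widthScale).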
// ------------------------------------------------------------------------------------------------------------------ //
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
@@ -890,9 +1096,9 @@ namespace helpers {
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
case kResizeArea: return resizeAreaFunctor(context, image, width, height, preserveAspectRatio, output);
case kResizeLanczos5:
case kResizeGaussian:
case kResizeArea:
case kResizeMitchelcubic:
throw std::runtime_error("helper::resizeFunctor: Non implemented yet.");
}

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -29,7 +30,7 @@ limitations under the License.
==============================================================================*/
//
// @author sgazeos@gmail.com
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/image_resize.h>
@@ -639,7 +640,7 @@ namespace helpers {
if (err != 0) {
cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err);
}
err = cudaMemcpy(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice);
err = cudaMemcpyAsync(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice, *stream);
if (err != 0) {
cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err);
}
@@ -847,7 +848,7 @@ namespace helpers {
if (err != 0) {
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err);
}
err = cudaMemcpy(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice);
err = cudaMemcpyAsync(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
if (err != 0) {
throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err);
}
@@ -927,6 +928,233 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES);
// ------------------------------------------------------------------------------------------------------------------ //
struct CachedInterpolation {
Nd4jLong start;
Nd4jLong end;
float startScale;
float endMinusOneScale;
bool needsBounding;
};
static __global__ void fillInterpolationCache(CachedInterpolation* xCached, Nd4jLong cacheLen, Nd4jLong inWidth, float widthScale) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto increment = blockDim.x * gridDim.x;
for (auto x = start; x < cacheLen; x += increment) {
auto& xCache = xCached[x];
const float inX = x * widthScale;
const float inX1 = (x + 1) * widthScale;
Nd4jLong v = math::nd4j_floor<float, Nd4jLong>(inX);
xCache.start = v;
xCache.startScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f);
v = math::nd4j_ceil<float, Nd4jLong>(inX1);
xCache.end = v--;
xCache.endMinusOneScale = v < inX ? (v + 1 > inX1 ? widthScale : v + 1 - inX) : (v + 1 > inX1 ? inX1 - v : 1.f);
xCache.needsBounding = bound(xCache.start, inWidth) != xCache.start || bound(xCache.end - 1, inWidth) != (xCache.end - 1);
}
}
// ------------------------------------------------------------------------------------------------------------------ //
template <typename T>
struct ScaleCache {
float yScale;
T const* yPtr;
};
// Computes the weighted sum of the input columns described by xCache, taken across
// the row pointers and scales held in yScaleCache, for the three-channel case.
//
// Bounding along x is applied only when xCache.needsBounding is set, so the check
// is hoisted out of the inner loops via the boundIfNeeded lambda.
template <typename T>
static __device__ void computePatchSumOf3Channels(float scale,
const ImageResizerState& st,
ScaleCache<T> const* yScaleCache,
Nd4jLong ptrsLen,
const CachedInterpolation& xCache,
float* outputPtr) {
bool const needsXBounding = xCache.needsBounding;
auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong {
return (needsXBounding ? bound(x, y) : (x));
};
float sum_0 = 0;
float sum_1 = 0;
float sum_2 = 0;
for (int i = 0; i < ptrsLen; ++i) {
const T* ptr = yScaleCache[i].yPtr;
float scaleX = xCache.startScale;
Nd4jLong offset = 3 * boundIfNeeded(xCache.start, st.inWidth);
float sum_y_0 = static_cast<float>(ptr[offset + 0]) * scaleX;
float sum_y_1 = static_cast<float>(ptr[offset + 1]) * scaleX;
float sum_y_2 = static_cast<float>(ptr[offset + 2]) * scaleX;
if (xCache.start + 1 != xCache.end) {
for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) {
Nd4jLong offset = 3 * boundIfNeeded(x, st.inWidth);
sum_y_0 += static_cast<float>(ptr[offset + 0]);
sum_y_1 += static_cast<float>(ptr[offset + 1]);
sum_y_2 += static_cast<float>(ptr[offset + 2]);
}
scaleX = xCache.endMinusOneScale;
offset = st.channels * boundIfNeeded(xCache.end - 1, st.inWidth);
sum_y_0 += static_cast<float>(ptr[offset + 0]) * scaleX;
sum_y_1 += static_cast<float>(ptr[offset + 1]) * scaleX;
sum_y_2 += static_cast<float>(ptr[offset + 2]) * scaleX;
}
sum_0 += sum_y_0 * yScaleCache[i].yScale;
sum_1 += sum_y_1 * yScaleCache[i].yScale;
sum_2 += sum_y_2 * yScaleCache[i].yScale;
}
outputPtr[0] = sum_0 * scale;
outputPtr[1] = sum_1 * scale;
outputPtr[2] = sum_2 * scale;
}
// Computes the weighted sum of the input columns described by xCache, taken across
// the row pointers and scales held in yScaleCache, separately for every channel c.
//
// Bounding along x is applied only when xCache.needsBounding is set, so the check
// is hoisted out of the inner loops via the boundIfNeeded lambda.
template <typename T>
static __device__ void computePatchSum(float scale, const ImageResizerState& st,
ScaleCache<T> const* yScaleCache, Nd4jLong ptrsLen,
const CachedInterpolation& xCache,
float* outputPtr) {
bool const needsXBounding = xCache.needsBounding;
auto boundIfNeeded = [needsXBounding](Nd4jLong x, Nd4jLong y) -> Nd4jLong {
return (needsXBounding ? bound(x, y) : (x));
};
const auto numChannels = st.channels;
for (Nd4jLong c = 0; c < numChannels; ++c) {
float sum = 0;
for (int i = 0; i < ptrsLen; ++i) {
T const* ptr = yScaleCache[i].yPtr;
float scaleX = xCache.startScale;
float sumY = static_cast<float>(ptr[numChannels * boundIfNeeded(xCache.start, st.inWidth) + c]) * scaleX;
if (xCache.start + 1 != xCache.end) {
for (Nd4jLong x = xCache.start + 1; x < xCache.end - 1; ++x) {
sumY += static_cast<float>(
ptr[numChannels * boundIfNeeded(x, st.inWidth) + c]);
}
scaleX = xCache.endMinusOneScale;
sumY += static_cast<float>(ptr[numChannels * boundIfNeeded(xCache.end - 1, st.inWidth) + c]) * scaleX;
}
sum += sumY * yScaleCache[i].yScale;
}
outputPtr[c] = sum * scale;
}
}
template <typename T>
static __global__ void resizeAreaKernel(ImageResizerState const* pSt, CachedInterpolation const* caches, float scale,
T const* inputPtr, Nd4jLong* inputShape, float* outputPtr, Nd4jLong* outputShape) {
__shared__ ScaleCache<T>* sharedPtr;
if (threadIdx.x == 0) {
extern __shared__ char shared[];
sharedPtr = reinterpret_cast<ScaleCache<T>*>(shared);
}
__syncthreads();
for (auto batch = blockIdx.x; batch < pSt->batchSize; batch += gridDim.x) {
for (auto y = threadIdx.x; y < pSt->outHeight; y += blockDim.x) {
const float inY = y * pSt->heightScale;
const float inY1 = (y + 1) * pSt->heightScale;
// The start and end height indices of all the cells that could
// contribute to the target cell.
const Nd4jLong yStart = math::nd4j_floor<float, Nd4jLong>(inY);
const Nd4jLong yEnd = math::nd4j_ceil<float, Nd4jLong>(inY1);
auto scalesDim = yEnd - yStart;
auto yScaleCache = sharedPtr + scalesDim * y * sizeof(ScaleCache<T>);
//auto startPtr = sharedPtr + y * scalesDim * sizeof(float);
//float* yScales = yScalesShare + y * sizeof(float) * scalesDim;//reinterpret_cast<float*>(startPtr); //shared + y * scalesDim * y + scalesDim * sizeof(T const *) [scalesDim];
//T const** yPtrs = yPtrsShare + y * sizeof(T const*) * scalesDim; //[scalesDim];
//yPtrs = reinterpret_cast<T const**>(sharedBuf);
float* output = outputPtr + (batch * pSt->outHeight + y) * pSt->channels * pSt->outWidth;
//int k = 0;
for (Nd4jLong i = yStart, k = 0; i < yEnd; ++i, ++k) {
float scaleY;
if (i < inY) {
scaleY = (i + 1 > inY1 ? pSt->heightScale : i + 1 - inY);
} else {
scaleY = (i + 1 > inY1 ? inY1 - i : 1.0);
}
yScaleCache[k].yScale = scaleY;
yScaleCache[k].yPtr = inputPtr + (batch * pSt->inHeight * pSt->inWidth * pSt->channels + bound(i, pSt->inHeight) * pSt->inWidth * pSt->channels);
}
if (pSt->channels == 3) {
for (Nd4jLong x = 0; x < pSt->outWidth; ++x) {
const CachedInterpolation& xCache = caches[x];
computePatchSumOf3Channels<T>(scale, *pSt, yScaleCache, scalesDim, xCache, output);
output += pSt->channels;
}
} else {
for (Nd4jLong x = 0; x < pSt->outWidth; ++x) {
const CachedInterpolation &xCache = caches[x];
computePatchSum<T>(scale, *pSt, yScaleCache, scalesDim, xCache, output);
output += pSt->channels;
}
}
}
}
}
template <typename T>
static void resizeArea(cudaStream_t* stream, ImageResizerState const& st, CachedInterpolation* cache,
NDArray const* input, NDArray* output) {
T const* inputPtr = reinterpret_cast<T const*>(input->getSpecialBuffer());
// float* yScales;
// T const** yPtrs;
float scale = 1.f / (st.heightScale * st.widthScale);
auto outputPtr = reinterpret_cast<float*>(output->specialBuffer()); // output is always float. TODO: support other float output types via a template <typename X, typename Z> declaration.
ImageResizerState* pSt;
auto err = cudaMalloc(&pSt, sizeof(ImageResizerState));
err = cudaMemcpyAsync(pSt, &st, sizeof(ImageResizerState), cudaMemcpyHostToDevice, *stream);
resizeAreaKernel<T><<<128, 4, 2048, *stream>>>(pSt, cache, scale, inputPtr, input->getSpecialShapeInfo(), outputPtr,
output->specialShapeInfo());
err = cudaStreamSynchronize(*stream);
err = cudaFree(pSt);
}
// ------------------------------------------------------------------------------------------------------------------ //
template <typename T>
int resizeAreaFunctor_(nd4j::LaunchContext* context, NDArray const* image, int const width, int const height,
bool const alignCorners, NDArray* output) {
ImageResizerState st(alignCorners, false); // Create resize info
auto res = st.validateAndCalculateOutputSize(image, width, height);
auto stream = context->getCudaStream();
if (Status::OK() == res) {
CachedInterpolation* xCached;
auto err = cudaMalloc(&xCached, sizeof(CachedInterpolation) * st.outWidth);
NDArray::prepareSpecialUse({output}, {image});
fillInterpolationCache<<<128, 128, 256, *stream>>>(xCached, st.outWidth, st.inWidth, st.widthScale);
resizeArea<T>(stream, st, xCached, image, output);
err = cudaStreamSynchronize(*stream);
err = cudaFree(xCached);
NDArray::registerSpecialUse({output}, {image});
}
return res;
}
int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeAreaFunctor_, (context, image, width, height, alignCorners, output), NUMERIC_TYPES);
}
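The CUDA wrappers above allocate the resizer state and the interpolation cache on the device and copy them with cudaMemcpyAsync on the op's stream, but the collected err values are never inspected. A hedged sketch of the same allocate-and-copy pattern with explicit error handling (illustrative only; copyStateToDevice is a hypothetical helper, not part of this commit):

#include <cuda_runtime.h>
#include <stdexcept>

// Copies a host-side state struct into freshly allocated device memory on `stream`,
// checking every CUDA runtime call; StateT stands in for ImageResizerState here.
template <typename StateT>
static StateT* copyStateToDevice(StateT const& hostState, cudaStream_t stream) {
    StateT* deviceState = nullptr;
    cudaError_t err = cudaMalloc(&deviceState, sizeof(StateT));
    if (err != cudaSuccess) throw std::runtime_error(cudaGetErrorString(err));
    err = cudaMemcpyAsync(deviceState, &hostState, sizeof(StateT), cudaMemcpyHostToDevice, stream);
    if (err != cudaSuccess) { cudaFree(deviceState); throw std::runtime_error(cudaGetErrorString(err)); }
    return deviceState;
}

// Usage sketch, mirroring the resizeArea host wrapper above (sharedMemBytes stands in
// for the 2048 bytes used there):
//   auto pSt = copyStateToDevice(st, *stream);
//   resizeAreaKernel<T><<<128, 4, sharedMemBytes, *stream>>>(pSt, cache, scale, inputPtr, ...);
//   cudaStreamSynchronize(*stream);
//   cudaFree(pSt);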
// ------------------------------------------------------------------------------------------------------------------ //
// simplified bicubic resize without antialiasing
//

View File

@@ -45,6 +45,9 @@ namespace helpers {
bool preserveAspectRatio, bool antialias, NDArray* output);
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
int resizeAreaFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, NDArray* output);
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);

View File

@@ -1087,6 +1087,304 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 4}, {
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f,
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f,
25.f, 26.f, 27.f, 28.f,
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f,
29.f, 30.f, 31.f, 32.f,
33.f, 34.f, 35.f, 36.f,
33.f, 34.f, 35.f, 36.f,
25.f, 26.f, 27.f, 28.f,
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f,
29.f, 30.f, 31.f, 32.f,
33.f, 34.f, 35.f, 36.f,
33.f, 34.f, 35.f, 36.f });
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) {
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 1});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f
});
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) {
NDArray input = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) {
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 3, 3}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 3}, {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 3}, {
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) {
NDArray input = NDArrayFactory::create<int>('c', {2, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9,
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {2, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) {
NDArray input = NDArrayFactory::create<int>('c', {1, 3, 3, 1}, {
1, 2, 3, 4, 5, 6, 7, 8, 9
});
NDArray expected = NDArrayFactory::create<float>('c', {1, 6, 6, 1}, {
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
1.f, 1.f, 1.5f, 2.f, 2.f, 3.f,
2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
4.f, 4.f, 4.5f, 5.f, 5.f, 6.f,
7.f, 7.f, 7.5f, 8.f, 8.f, 9.f
});
//input.linspace(1);
// auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_area op;
auto results = op.execute({&input}, {}, {6, 6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Area Resized to 6x6");
// expected.printBuffer("Area Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {

View File

@@ -537,6 +537,20 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
}
TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) {
auto x = NDArrayFactory::create<double>('c', {1, 2});
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {4607182418800017408LL, 4611686018427387904LL}); // as TF 4607182418800017408, 4611686018427387904
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status());
// e.printIndexedBuffer("Double to int64");
auto res = result->at(0);
ASSERT_EQ(*res, e);
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {

View File

@@ -89,6 +89,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
org.nd4j.linalg.api.ops.impl.image.ResizeBicubic.class,
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
org.nd4j.linalg.api.ops.impl.image.ResizeArea.class,
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IAMin.class,

View File

@@ -0,0 +1,114 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit, K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class ResizeArea extends DynamicCustomOp {
protected boolean alignCorners = false;
protected Integer height = null;
protected Integer width = null;
public ResizeArea(@NonNull SameDiff sd, @NonNull SDVariable image, int height, int width,
boolean alignCorners) {
super(sd, image);
this.alignCorners = alignCorners;
this.height = height;
this.width = width;
addArgs();
}
public ResizeArea(@NonNull INDArray x, INDArray z, int height, int width,
boolean alignCorners) {
super(new INDArray[]{x}, new INDArray[]{z});
this.alignCorners = alignCorners;
this.height = height;
this.width = width;
addArgs();
}
@Override
public String opName() {
return "resize_area";
}
@Override
public String tensorflowName() {
return "ResizeArea";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
val attrC = attributesForNode.get("align_corners");
this.alignCorners = attrC != null ? attrC.getB() : false;
addArgs();
}
protected void addArgs() {
iArguments.clear();
if(height != null && width != null){
INDArray size = Nd4j.createFromArray(new int[]{height,width});
addInputArgument(size);
//iArguments.add(Long.valueOf(height));
//iArguments.add(Long.valueOf(width));
}
addBArgument(alignCorners);
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("alignCorners", alignCorners);
ret.put("height", height);
ret.put("width", width);
return ret;
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException();
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(DataType.FLOAT);
}
}

View File

@@ -34,6 +34,7 @@ import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
import org.nd4j.linalg.api.ops.impl.image.ResizeArea;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
@@ -968,6 +969,33 @@ public class CustomOpsTests extends BaseNd4jTest {
Nd4j.exec(op);
}
@Test
public void testResizeArea1() {
INDArray x = Nd4j.rand(DataType.FLOAT, 1, 2,3,4);
INDArray z = Nd4j.createUninitialized(DataType.FLOAT, 1, 10, 10, 4);
ResizeArea op = new ResizeArea(x, z, 10, 10, false);
Nd4j.exec(op);
}
@Test
public void testResizeArea2() {
INDArray image = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 9 ).reshape(1,3,3,1);
INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1, 6, 6, 1);
INDArray expected = Nd4j.createFromArray(new float[]{
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
1.f, 1.f, 2.f, 2.f, 3.f, 3.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
4.f, 4.f, 5.f, 5.f, 6.f, 6.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f,
7.f, 7.f, 8.f, 8.f, 9.f, 9.f
}).reshape(1,6,6,1);
ResizeArea op = new ResizeArea(image, output, 6, 6, false);
Nd4j.exec(op);
assertEquals(expected, output);
}
@Test
public void testCompareAndBitpack() {
INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f,