Shugeo image resize bicubic (#56)

* Added implementation files for image_resize and resize_bicubic ops.

* Image resize and image.resize_bicubic ops implementation. Initial revision.

* Finished with infrastructure development for image.resize_bilinear op and image_resize op implementation.

* Refactored resize methods.

* Added processing for Mitchell cubic algorithm.

* Added check for input/output sizes.

* Added int and float types for crop_and_resize op.

* Refactored crop_and_resize output type check.

* Added helper for bicubic interpolation as TF v.1 does.

* Added TF v.1 bicubic helper for cuda platform.

* Added cached class for bicubic algorithm.

* Refactored cuda implementation for crop_and_resize helper to use proper output type.

* Added facilities for bicubic interpolation.

* Portion bicubic interpolation from TF.

* Added tests for resize_bilinear testing.

* Working implementation of bicubic interpolation and tests.

* Refactored routines with image_resize bicubic op helper.

* Refactored code with coding standards.

* Refactored cpu helpers for resize_bicubic op.

* Refactored bicubic helpers.

* Added bicubic resize facilities.

* Implementing cuda kernels for bicubic interpolation. Implementation step.

* Cuda implementation of resize_bicubic op helper.

* Refactor image.resize_bicubic op helpers.

* Refactored helpers for resize_bicubic. Added error checking with cuda implementation.

* Refactored cuda implementation of resize_bicubic op helper. The first working revision.

* Cuda arch implementation for resize_bicubic op helper. Full working single-threaded revision.

* Intermediate bicubic interpolation helper for cuda.

* Refactored cpu helper for resize_bicubic.

* Multithreaded cuda implementation for resize_bicubic.

* Fixed merge issues.

* Refactored nlp helpers.

* Replicated resize_bicubic for 3D also.

* Eliminated waste comments of unused code.

* Eliminated waste comments with unused code.

* Eliminated waste template definitions.

* Eliminated waste debug code.

* Eliminated waste comments.

* Fixed multithreading with helpers.

* Fixed test suites for float and double in float point input lists.

* Fixed usage of reshape with 3D/4D on resizes.

* Final fixes.

* Fixed resize_neighbor op problem.
master
shugeo 2019-11-20 21:11:04 +02:00 committed by GitHub
parent 13e5c0a280
commit dc0036f2c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 2051 additions and 179 deletions

View File

@ -39,6 +39,7 @@ namespace nd4j {
double extrapolationVal = 0.; double extrapolationVal = 0.;
auto newImageSize = INPUT_VARIABLE(3); auto newImageSize = INPUT_VARIABLE(3);
REQUIRE_TRUE(output->dataType() == image->dataType(), 0, "crop_and_resize: Source images and output should have the same data type.");
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "crop_and_resize: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
//REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive."); //REQUIRE_TRUE(block.numI() <= 1, 0, "crop_and_resize: Resize params already given by the second param. Int params are expensive.");
//width = int(newImageSize->getScalar(0)); //width = int(newImageSize->getScalar(0));
@ -74,17 +75,17 @@ namespace nd4j {
outputShape[2] = height; outputShape[2] = height;
outputShape[3] = in[4]; outputShape[3] = in[4];
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(DataType::FLOAT32, shape::order(in), outputShape, 4))); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(in), shape::order(in), outputShape, 4)));
} }
DECLARE_TYPES(crop_and_resize) { DECLARE_TYPES(crop_and_resize) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
// ->setAllowedInputTypes(1, {ALL_FLOATS}) // ->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedInputTypes(1, {FLOAT32}) // as TF ->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS})
->setAllowedInputTypes(2, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_INTS})
->setAllowedInputTypes(3, {ALL_INTS}) ->setAllowedInputTypes(3, {ALL_INTS})
->setAllowedOutputTypes({FLOAT32}); // as TF ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); // as TF
// ->setAllowedOutputTypes({ALL_FLOATS}); // ->setAllowedOutputTypes({ALL_FLOATS});
} }
} }

View File

@ -29,30 +29,40 @@ namespace nd4j {
CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(resize_bicubic, 2, 1, false, 0, 0) {
auto image = INPUT_VARIABLE(0); auto image = INPUT_VARIABLE(0);
auto size = INPUT_VARIABLE(1); auto size = INPUT_VARIABLE(1); // integer vector with shape {2} and content (new_height, new_width)
size->syncToHost();
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
int width; int width;
int height; int height;
bool center = false; // - default value auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
REQUIRE_TRUE(size->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", size->lengthOf()); REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
width = size->e<int>(0); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bicubic: Resize params already given by the second param. Int params are expensive.");
height = size->e<int>(1); width = size->e<int>(1);
auto method = 1; //kResizeBilinear; height = size->e<int>(0);
if (block.numI() == 1) { REQUIRE_TRUE(width > 0 , 0, "resize_bicubic: picture width should be positive 32 bit integer, but %i given", width);
method = INT_ARG(0); REQUIRE_TRUE(height > 0 , 0, "resize_bicubic: picture height should be positive 32 bit integer, but %i given", height);
} //REQUIRE_TRUE(image->sizeAt(1) > 3 && image->sizeAt(2) > 3, 0, "resize_cubic: To use bicubic algorithm need at least 16 pixels as source.");
auto preserveAspectRatio = false; REQUIRE_TRUE(width > 3 && height > 3, 0, "resize_bicubic: To use bicubic algorithm need at least 16 pixels as target.");
auto antialias = false; REQUIRE_TRUE(image->lengthOf() > 0, 0, "resize_bicubic: Only non-zero images allowed to processing.");
// auto method = 1; //kResizeBilinear;
// if (block.numI() == 1) {
// method = INT_ARG(0);
// }
auto alignCorners = false;
auto halfPixelAlign = false;
if (block.numB() > 0) { if (block.numB() > 0) {
preserveAspectRatio = block.getBArguments()->at(0); alignCorners = block.getBArguments()->at(0);
if (block.numB()> 1) if (block.numB()> 1)
antialias = block.getBArguments()->at(1); halfPixelAlign = block.getBArguments()->at(1);
} }
REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners");
return helpers::resizeBicubicFunctor(block.launchContext(), image, width, height, preserveAspectRatio, antialias, output); auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
} }
DECLARE_SHAPE_FN(resize_bicubic) { DECLARE_SHAPE_FN(resize_bicubic) {
@ -60,7 +70,7 @@ namespace nd4j {
auto in = inputShape->at(0); auto in = inputShape->at(0);
Nd4jLong* outputShape; Nd4jLong* outputShape;
auto inRank = shape::rank(in);
int width; int width;
int height; int height;
auto newImageSize = INPUT_VARIABLE(1); auto newImageSize = INPUT_VARIABLE(1);
@ -69,12 +79,21 @@ namespace nd4j {
width = newImageSize->e<int>(0); width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1); height = newImageSize->e<int>(1);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
outputShape[0] = 4;
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {
outputShape[1] = in[1]; outputShape[1] = in[1];
outputShape[2] = width; outputShape[2] = width;
outputShape[3] = height; outputShape[3] = height;
outputShape[4] = in[4]; outputShape[4] = in[4];
}
else {
outputShape[1] = width;
outputShape[2] = height;
outputShape[3] = in[3];
}
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
shapeList->push_back(CONSTANT(outputShape)); shapeList->push_back(CONSTANT(outputShape));
@ -83,7 +102,7 @@ namespace nd4j {
DECLARE_TYPES(resize_bicubic) { DECLARE_TYPES(resize_bicubic) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(1, {DataType::INT32})
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ALL_FLOATS});
} }

View File

@ -60,8 +60,7 @@ namespace nd4j {
center = 0 != INT_ARG(2); center = 0 != INT_ARG(2);
} }
auto res = helpers::resizeBilinearFunctor(block.launchContext(), &source, width, height, center, &target); return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
return res;
} }
DECLARE_SHAPE_FN(resize_bilinear) { DECLARE_SHAPE_FN(resize_bilinear) {

View File

@ -54,10 +54,10 @@ namespace nd4j {
auto inRank = image->rankOf(); auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeNeighborFunctor(block.launchContext(), &source, width, height, center, &target); return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
} }
DECLARE_SHAPE_FN(resize_nearest_neighbor) { DECLARE_SHAPE_FN(resize_nearest_neighbor) {

View File

@ -89,16 +89,10 @@ namespace helpers {
} }
if (colorTable.empty()) if (colorTable.empty())
colorTable = DefaultColorTable(channels); colorTable = DefaultColorTable(channels);
// PRAGMA_OMP_PARALLEL_FOR
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto batch = 0; batch < batchSize; ++batch) { // loop by batch for (auto batch = start; batch < stop; ++batch) { // loop by batch
// auto image = imageList->at(batch);
const Nd4jLong numBoxes = boxes->sizeAt(1); const Nd4jLong numBoxes = boxes->sizeAt(1);
for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) { for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
// box with shape
//auto internalBox = (*boxes)(batch, {0})(boxIndex, {0});//internalBoxes->at(c);
//auto color = colorSet->at(c);
//internalBox.printIndexedBuffer("Current Box");
auto colorIndex = boxIndex % colorTable.size(); auto colorIndex = boxIndex % colorTable.size();
auto rowStart = Nd4jLong((height - 1) * boxes->t<float>(batch, boxIndex, 0)); auto rowStart = Nd4jLong((height - 1) * boxes->t<float>(batch, boxIndex, 0));
auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong(0), rowStart); auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong(0), rowStart);
@ -152,20 +146,7 @@ namespace helpers {
output->p(batch, i, colEnd, c, colorTable[colorIndex][c]); output->p(batch, i, colEnd, c, colorTable[colorIndex][c]);
} }
} }
// for (auto y = rowStart; y <= rowEnd; y++) {
// for (auto e = 0; e < channels; ++e) {
// output->p(batch, y, colStart, e, colorTable[colorIndex][e]);
// output->p(batch, y, colEnd, e, colorTable[colorIndex][e]);
// }
// }
// for (auto x = colStart + 1; x < colEnd; x++) {
// for (auto e = 0; e < channels; ++e) {
// output->p(batch, rowStart, x, e, colorTable[colorIndex][e]);
// output->p(batch, rowEnd, x, e, colorTable[colorIndex][e]);
// }
// }
} }
// delete internalBoxes;
} }
}; };
samediff::Threads::parallel_tad(func, 0, batchSize); samediff::Threads::parallel_tad(func, 0, batchSize);

View File

@ -47,6 +47,91 @@ namespace helpers {
// https://en.wikipedia.org/wiki/Bilinear_interpolation) // https://en.wikipedia.org/wiki/Bilinear_interpolation)
double interpolarValue; double interpolarValue;
}; };
// Returns the ratio used to map an output coordinate back to an input
// coordinate along one axis. With alignCorners set (and more than one output
// element) the ratio is (inSize - 1) / (outSize - 1); otherwise it is
// inSize / outSize.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
                                  bool alignCorners) {
    if (alignCorners && outSize > 1) {
        return (inSize - 1) / static_cast<float>(outSize - 1);
    }
    return inSize / static_cast<float>(outSize);
}
struct ImageResizerState {
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
: _alignCorners(alignCorners),
_halfPixelCenters(halfPixelCenters) {}
// ValidateAndCalculateOutputSize checks the bounds on the input tensors
// and requested size, sets up some of the resizing state such as the
// heightScale and widthScale, and calculates the output size.
// If any of these operations fails, it sets an error status in
// the context, which the caller must check.
int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) {
//
batchSize = input->sizeAt(0);//.dim_size(0);
outHeight = height;
outWidth = width; //internal::SubtleMustCopy(Svec(1));
inHeight = static_cast<int32_t>(input->sizeAt(1));
inWidth = static_cast<int32_t>(input->sizeAt(2));
channels = input->sizeAt(3); //.dim_size(3);
heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners);
widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners);
// Guard against overflows
if (ceilf((outHeight - 1) * heightScale) > static_cast<float>(DataTypeUtils::max<int>())) {
nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale));
return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height");
}
if (ceilf((outWidth - 1) * heightScale) > static_cast<float>(DataTypeUtils::max<int>())) {
nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale));
return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width");
}
return Status::OK();
}
// Calculates all the required variables, and allocates the output.
int validateAndCreateOutput(NDArray const* input, int const width, int const height) {
return validateAndCalculateOutputSize(input, width, height);
}
Nd4jLong batchSize;
Nd4jLong outHeight;
Nd4jLong outWidth;
Nd4jLong inHeight;
Nd4jLong inWidth;
Nd4jLong channels;
float heightScale;
float widthScale;
NDArray* output = nullptr;
private:
bool _alignCorners;
bool _halfPixelCenters;
};
// Maps an output index to a source coordinate under the half-pixel-centre
// convention, where the centre of the top-left pixel sits at (0.5, 0.5).
struct HalfPixelScaler {
    // User-provided so that `const HalfPixelScaler s;` can be default-initialised.
    HalfPixelScaler() {}

    inline float operator()(const int x, const float scale) const {
        // Shift to the pixel centre, scale, then shift back by 0.5 because the
        // sampling code that consumes the result works in the old
        // (corner-based) coordinate system.
        const float centre = static_cast<float>(x) + 0.5f;
        return centre * scale - 0.5f;
    }
};
// Plain record of four interpolation weights and the four indices they apply
// to. It is the output type of getWeightsAndIndices (passed as `out`).
struct WeightsAndIndices {
// Weights paired with _index0.._index3 respectively (presumably; confirm in getWeightsAndIndices).
float _weight0;
float _weight1;
float _weight2;
float _weight3;
// Indices of the four taps along one axis.
Nd4jLong _index0;
Nd4jLong _index1;
Nd4jLong _index2;
Nd4jLong _index3;
// NOTE(review): looks like the value returned by
// CachedInterpolationCalculator::Advance is stored here - confirm at the call site.
int _advance; // advance value.
};
inline void computeInterpolationWeights(Nd4jLong outSize, inline void computeInterpolationWeights(Nd4jLong outSize,
Nd4jLong inSize, Nd4jLong inSize,
@ -55,13 +140,16 @@ namespace helpers {
interpolationData[outSize].bottomIndex = 0; interpolationData[outSize].bottomIndex = 0;
interpolationData[outSize].topIndex = 0; interpolationData[outSize].topIndex = 0;
PRAGMA_OMP_PARALLEL_FOR auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong i = outSize - 1; i >= 0; --i) { for (auto k = start; k < stop; k++) {
auto i = (outSize - k - 1);
double in = i * scale; double in = i * scale;
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in); interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
} }
};
samediff::Threads::parallel_for(func, 0, outSize);
} }
/** /**
@ -87,10 +175,10 @@ namespace helpers {
Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong inBatchNumValues = inHeight * inRowSize;
Nd4jLong outRowSize = outWidth * channels; Nd4jLong outRowSize = outWidth * channels;
T const *input_b_ptr = reinterpret_cast<T const *>(images->getBuffer()); // this works only with 'c' direction T const *pInput = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
BilinearInterpolationData const *xs_ = xs.data(); BilinearInterpolationData const *xs_ = xs.data();
T *output_y_ptr = reinterpret_cast<T *>(output->buffer()); T* pOutput = output->dataBuffer()->primaryAsT<T>();
auto computeBilinear = [](double topLeft, double topRight, auto computeBilinear = [](double topLeft, double topRight,
double bottomLeft, double bottomRight, double bottomLeft, double bottomRight,
double xVal, double yVal) { double xVal, double yVal) {
@ -100,28 +188,28 @@ namespace helpers {
}; };
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong b = 0; b < batchSize; ++b) { for (auto b = start; b < stop; ++b) {
for (Nd4jLong y = 0; y < outHeight; ++y) { for (auto y = 0; y < outHeight; ++y) {
const T *ys_input_lower_ptr = input_b_ptr + ys[y].bottomIndex * inRowSize; const T *ys_input_lower_ptr = pInput + ys[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = input_b_ptr + ys[y].topIndex * inRowSize; const T *ys_input_upper_ptr = pInput + ys[y].topIndex * inRowSize;
double yVal = ys[y].interpolarValue; double yVal = ys[y].interpolarValue;
for (Nd4jLong x = 0; x < outWidth; ++x) { for (auto x = 0; x < outWidth; ++x) {
auto xsBottom = xs_[x].bottomIndex; auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex; auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue; auto xVal = xs_[x].interpolarValue;
for (int c = 0; c < channels; ++c) { for (auto c = 0; c < channels; ++c) {
double topLeft(ys_input_lower_ptr[xsBottom + c]); double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]); double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]); double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
double bottomRight(ys_input_upper_ptr[xsTop + c]); double bottomRight(ys_input_upper_ptr[xsTop + c]);
output_y_ptr[x * channels + c] = pOutput[x * channels + c] =
computeBilinear(topLeft, topRight, bottomLeft, bottomRight, computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
xVal, yVal); xVal, yVal);
} }
} }
output_y_ptr += outRowSize; pOutput += outRowSize;
} }
input_b_ptr += inBatchNumValues; pInput += inBatchNumValues;
} }
}; };
samediff::Threads::parallel_tad(func, 0, batchSize); samediff::Threads::parallel_tad(func, 0, batchSize);
@ -192,7 +280,7 @@ namespace helpers {
// Handle no-op resizes efficiently. // Handle no-op resizes efficiently.
if (outHeight == inHeight && outWidth == inWidth) { if (outHeight == inHeight && outWidth == inWidth) {
output->assign(images); output->assign(images);
return ND4J_STATUS_OK; return Status::OK();
} }
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
@ -209,9 +297,9 @@ namespace helpers {
for (auto y = start_y; y < stop_y; y += inc_y) { for (auto y = start_y; y < stop_y; y += inc_y) {
Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(y * heightScale)), inHeight - 1); Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(y * heightScale)), inHeight - 1);
for (int x = 0; x < outWidth; ++x) { for (auto x = 0; x < outWidth; ++x) {
Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(x * widthScale)),inWidth - 1); Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(x * widthScale)),inWidth - 1);
for (Nd4jLong e = 0; e < channels; e++) for (auto e = 0; e < channels; e++)
output->p(b, y, x, e, images->e<T>(b, inY, inX, e)); output->p(b, y, x, e, images->e<T>(b, inY, inX, e));
} }
} }
@ -232,28 +320,16 @@ namespace helpers {
LIBND4J_TYPES); LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void resizeImage_,
(NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels,
std::vector<BilinearInterpolationData> const& xs,
std::vector<BilinearInterpolationData> const& ys,
NDArray* output), LIBND4J_TYPES);
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_,
(images, width, height, center, output), LIBND4J_TYPES); (images, width, height, center, output), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_,
(NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
(images, width, height, center, output), LIBND4J_TYPES); (images, width, height, center, output), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_,
(NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
template<typename T, typename F, typename I> template<typename T, typename F, typename I>
static void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices, static void cropAndResizeFunctor_(NDArray const *images, NDArray const *boxes, NDArray const *indices,
@ -267,7 +343,7 @@ namespace helpers {
const int cropWidth = crops->sizeAt(2); const int cropWidth = crops->sizeAt(2);
const int depth = crops->sizeAt(3); const int depth = crops->sizeAt(3);
for (int b = 0; b < numBoxes; ++b) { for (auto b = 0; b < numBoxes; ++b) {
T y1 = boxes->t<F>(b, 0); T y1 = boxes->t<F>(b, 0);
T x1 = boxes->t<F>(b, 1); T x1 = boxes->t<F>(b, 1);
T y2 = boxes->t<F>(b, 2); T y2 = boxes->t<F>(b, 2);
@ -282,14 +358,14 @@ namespace helpers {
T widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0); T widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : T(0);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (int y = start; y < stop; y += increment) { for (auto y = start; y < stop; y += increment) {
const float inY = (cropHeight > 1) const float inY = (cropHeight > 1)
? y1 * (imageHeight - 1) + y * heightScale ? y1 * (imageHeight - 1) + y * heightScale
: 0.5 * (y1 + y2) * (imageHeight - 1); : 0.5 * (y1 + y2) * (imageHeight - 1);
if (inY < 0 || inY > imageHeight - 1) { if (inY < 0 || inY > imageHeight - 1) {
for (int x = 0; x < cropWidth; ++x) { for (auto x = 0; x < cropWidth; ++x) {
for (int d = 0; d < depth; ++d) { for (auto d = 0; d < depth; ++d) {
crops->p(b, y, x, d, extrapolationVal); crops->p(b, y, x, d, extrapolationVal);
} }
} }
@ -300,13 +376,13 @@ namespace helpers {
const int bottomYIndex = nd4j::math::p_ceil(inY); const int bottomYIndex = nd4j::math::p_ceil(inY);
const float y_lerp = inY - topYIndex; const float y_lerp = inY - topYIndex;
for (int x = 0; x < cropWidth; ++x) { for (auto x = 0; x < cropWidth; ++x) {
const float in_x = (cropWidth > 1) const float in_x = (cropWidth > 1)
? x1 * (imageWidth - 1) + x * widthScale ? x1 * (imageWidth - 1) + x * widthScale
: 0.5 * (x1 + x2) * (imageWidth - 1); : 0.5 * (x1 + x2) * (imageWidth - 1);
if (in_x < 0 || in_x > imageWidth - 1) { if (in_x < 0 || in_x > imageWidth - 1) {
for (int d = 0; d < depth; ++d) { for (auto d = 0; d < depth; ++d) {
crops->p(b, y, x, d, extrapolationVal); crops->p(b, y, x, d, extrapolationVal);
} }
continue; continue;
@ -315,7 +391,7 @@ namespace helpers {
int right_x_index = math::p_ceil(in_x); int right_x_index = math::p_ceil(in_x);
T x_lerp = in_x - left_x_index; T x_lerp = in_x - left_x_index;
for (int d = 0; d < depth; ++d) { for (auto d = 0; d < depth; ++d) {
const float topLeft(images->e<float>(bIn, topYIndex, left_x_index, d)); const float topLeft(images->e<float>(bIn, topYIndex, left_x_index, d));
const float topRight(images->e<float>(bIn, topYIndex, right_x_index, d)); const float topRight(images->e<float>(bIn, topYIndex, right_x_index, d));
const float bottomLeft(images->e<float>(bIn, bottomYIndex, left_x_index, d)); const float bottomLeft(images->e<float>(bIn, bottomYIndex, left_x_index, d));
@ -326,21 +402,21 @@ namespace helpers {
} }
} }
} else { // method is "nearest neighbor" } else { // method is "nearest neighbor"
for (int x = 0; x < cropWidth; ++x) { for (auto x = 0; x < cropWidth; ++x) {
const float inX = (cropWidth > 1) const float inX = (cropWidth > 1)
? x1 * (imageWidth - 1) + x * widthScale ? x1 * (imageWidth - 1) + x * widthScale
: 0.5 * (x1 + x2) * (imageWidth - 1); : 0.5 * (x1 + x2) * (imageWidth - 1);
if (inX < 0 || inX > imageWidth - 1) { if (inX < 0 || inX > imageWidth - 1) {
for (int d = 0; d < depth; ++d) { for (auto d = 0; d < depth; ++d) {
crops->p(b, y, x, d, extrapolationVal); crops->p(b, y, x, d, extrapolationVal);
} }
continue; continue;
} }
const int closestXIndex = roundf(inX); const int closestXIndex = roundf(inX);
const int closestYIndex = roundf(inY); const int closestYIndex = roundf(inY);
for (int d = 0; d < depth; ++d) { for (auto d = 0; d < depth; ++d) {
crops->p(b, y, x, d, (F) images->e<T>(bIn, closestYIndex, closestXIndex, d)); crops->p(b, y, x, d, images->e<T>(bIn, closestYIndex, closestXIndex, d));
} }
} }
} }
@ -350,12 +426,443 @@ namespace helpers {
samediff::Threads::parallel_for(func, 0, cropHeight); samediff::Threads::parallel_for(func, 0, cropHeight);
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// ------------------------------------------------------------------------------------------------------------------ //
// Bicubic interpolation
// ------------------------------------------------------------------------------------------------------------------ //
// Tracks the four indices used for the previous position so that a caller can
// tell how many already-computed values are reusable for the next position.
class CachedInterpolationCalculator {
public:
// All four cached slots start at -1.
CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {}
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, // Advances iteration. Returns the number of values that should be copied from
// the current point to the next point. The copying should always be done by
// copying the last <retval> values from the old point to the first <retval>
// values of the new point.
inline int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2,
const Nd4jLong x3) {
// We use 2 hands and walk through, copying from one to another where
// we already have values.
// Invariant: newIndiciesHand <= cachedValuesHand
const Nd4jLong new_x_indices[] = {x0, x1, x2, x3};
int cachedValuesHand = 0;
int newIndiciesHand = 0;
while (cachedValuesHand < 4) {
// A cached index that equals the next wanted index is kept and, if needed,
// moved down to the slot the new ordering expects.
if (_indexes[cachedValuesHand] == new_x_indices[newIndiciesHand]) {
if (newIndiciesHand < cachedValuesHand) {
_indexes[newIndiciesHand] = _indexes[cachedValuesHand];
}
newIndiciesHand++;
}
cachedValuesHand++;
}
// Deliberate fall-through: starting at the first slot that was not matched,
// every remaining slot is overwritten with the new index. When all four
// matched (newIndiciesHand == 4) no case applies and nothing is written.
switch (newIndiciesHand) {
case 0:
_indexes[0] = x0;
case 1:
_indexes[1] = x1;
case 2:
_indexes[2] = x2;
case 3:
_indexes[3] = x3;
break;
}
// Number of leading slots that were satisfied from the cache.
return newIndiciesHand;
}
private:
// Indices cached by the most recent Advance call.
Nd4jLong _indexes[4];
};
// Number of sub-pixel steps in the coefficient table; the table holds
// kTableSize + 1 entries so that both end points (0 and 1) are present.
static const Nd4jLong kTableSize = 1024LL; //(1 << 10);

// Allocates and fills a coefficient table for the bicubic convolution
// algorithm (https://en.wikipedia.org/wiki/Bicubic_interpolation).
//
// For every i in [0, kTableSize] with x = i / kTableSize:
//   table[2 * i]     = ((a + 2) * x - (a + 3)) * x * x + 1           (distance x)
//   table[2 * i + 1] = ((a * x' - 5a) * x' + 8a) * x' - 4a, x' = x + 1 (distance x + 1)
//
// @param a  cubic convolution parameter (the visible callers pass -0.5 and -0.75)
// @return   pointer to (kTableSize + 1) * 2 floats allocated with new[];
//           this function never frees it.
const float* initCoeffsTable(const double a) {
    float* coeffs_table = new float[(kTableSize + 1) * 2];
    auto func = PRAGMA_THREADS_FOR {
        // [start, stop) is a half-open span, as in the other threaded loops of
        // this file. Iterating with `<= stop` made neighbouring spans both write
        // the boundary entry, so the range below is extended by one instead.
        for (auto i = start; i < stop; ++i) {
            float x = i * 1.0 / kTableSize;
            coeffs_table[i * 2] = ((a + 2) * x - (a + 3)) * x * x + 1;
            x += 1.0;
            coeffs_table[i * 2 + 1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
        }
    };
    // kTableSize + 1 iterations so that the last entry (i == kTableSize) is filled.
    samediff::Threads::parallel_for(func, 0, kTableSize + 1);
    return coeffs_table;
}
// Returns the coefficient table for the requested cubic kernel. Each table is
// a function-local static, so it is built by initCoeffsTable on first use and
// the same pointer is returned afterwards.
const float* getCoeffsTable(const bool use_keys_cubic) {
    if (!use_keys_cubic) {
        static const float* legacyTable = initCoeffsTable(-0.75f);
        return legacyTable;
    }
    // http://ieeexplore.ieee.org/document/1163711/
    // R. G. Keys. Cubic convolution interpolation for digital image
    // processing. IEEE Transactions on Acoustics, Speech, and Signal
    // Processing, 29(6):1153-1160, 1981.
    static const float* keysTable = initCoeffsTable(-0.5f);
    return keysTable;
}
// Clamps val into the inclusive range [0, limit - 1].
inline Nd4jLong bound(Nd4jLong val, Nd4jLong limit) {
    const auto notBelowZero = math::nd4j_max(Nd4jLong{0}, val);
    return math::nd4j_min(limit - 1ll, notBelowZero);
}
// Typed implementation behind resizeBicubicFunctor(). In this revision it is a
// stub: no argument is read and ND4J_STATUS_OK is returned unconditionally.
template <typename T>
int resizeBicubicFunctor_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
                          bool preserveAspectRatio, bool antialias, NDArray* output) {
    return ND4J_STATUS_OK;
}
// Runtime-typed entry point for bicubic resize: BUILD_SINGLE_SELECTOR picks the
// resizeBicubicFunctor_ instantiation matching image->dataType() (from
// NUMERIC_TYPES), forwards every argument unchanged and returns its status.
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image,
width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES);
}
// ------------------------------------------------------------------------------------------------------------------ //
// Weighted sum of four samples: each value is converted to float, multiplied by
// its weight, and the products are accumulated in order 0, 1, 2, 3.
template <typename T>
inline float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3,
                           const T value0, const T value1, const T value2, const T value3) {
    float acc = static_cast<float>(value0) * weight0;
    acc = acc + static_cast<float>(value1) * weight1;
    acc = acc + static_cast<float>(value2) * weight2;
    acc = acc + static_cast<float>(value3) * weight3;
    return acc;
}
// Combines four cached column values with the four x-weights of the current
// output pixel (delegates to interpolate1D).
static float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) {
    const float v0 = values[0];
    const float v1 = values[1];
    const float v2 = values[2];
    const float v3 = values[3];
    return interpolate1D(xW0, xW1, xW2, xW3, v0, v1, v2, v3);
}
template <typename Scaler, bool use_keys_cubic>
inline void getWeightsAndIndices(const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) {
const Scaler scaler;
const float in_loc_f = scaler(out_loc, scale);
const Nd4jLong in_loc = std::floor(in_loc_f);
const float delta = in_loc_f - in_loc;
const Nd4jLong offset = lrintf(delta * kTableSize);
const float* coeffs_table = getCoeffsTable(use_keys_cubic);
if (use_keys_cubic) {
// The legacy code placed more weight on the edge pixels, since bounding
// the set of inputs to sample could cause an edge pixel to be repeated.
// Here we change the behavior at borders to match that used by the
// scale_and_translate_op, where sampling locations outside the image have
// their weight set to 0, and the weights are renormalized so that their sum
// is 1.0.
out->_index0 = bound(in_loc - 1, limit);
out->_weight0 =
(out->_index0 == in_loc - 1 ? coeffs_table[offset * 2 + 1] : 0.0f);
out->_index1 = bound(in_loc, limit);
out->_weight1 = (out->_index1 == in_loc ? coeffs_table[offset * 2] : 0.0f);
out->_index2 = bound(in_loc + 1, limit);
out->_weight2 =
(out->_index2 == in_loc + 1 ? coeffs_table[(kTableSize - offset) * 2]
: 0.0f);
out->_index3 = bound(in_loc + 2, limit);
out->_weight3 = (out->_index3 == in_loc + 2
? coeffs_table[(kTableSize - offset) * 2 + 1]
: 0.0f);
const float weight_sum =
out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3;
if (std::abs(weight_sum) >= 1000.0f * std::numeric_limits<float>::min()) {
const float one_over_weight_sum = 1.0f / weight_sum;
out->_weight0 *= one_over_weight_sum;
out->_weight1 *= one_over_weight_sum;
out->_weight2 *= one_over_weight_sum;
out->_weight3 *= one_over_weight_sum;
}
} else {
out->_weight0 = coeffs_table[offset * 2 + 1];
out->_weight1 = coeffs_table[offset * 2];
out->_weight2 = coeffs_table[(kTableSize - offset) * 2];
out->_weight3 = coeffs_table[(kTableSize - offset) * 2 + 1];
out->_index0 = bound(in_loc - 1, limit);
out->_index1 = bound(in_loc, limit);
out->_index2 = bound(in_loc + 1, limit);
out->_index3 = bound(in_loc + 2, limit);
}
}
// Legacy coordinate mapping: source = out * scale, with no half-pixel offset.
// This older method is incorrect: it shifts every resize slightly, so results
// are inconsistent (e.g. flip-then-resize differs from resize-then-flip).
struct LegacyScaler {
    LegacyScaler() {}
    inline float operator()(const int x, const float scale) const {
        const float outCoord = static_cast<float>(x);
        return outCoord * scale;
    }
};
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
const bool half_pixel_centers,
std::vector<WeightsAndIndices>* x_wais) {
CachedInterpolationCalculator calc;
if (half_pixel_centers) {
auto func = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; ++x) {
getWeightsAndIndices<HalfPixelScaler, true>(
resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]);
auto &x_wai = (*x_wais)[x];
x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2,
x_wai._index3);
}
};
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
} else {
auto func = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; ++x) {
getWeightsAndIndices<LegacyScaler, false>(
resizer_state.widthScale, x, resizer_state.inWidth, &(*x_wais)[x]);
auto& x_wai = (*x_wais)[x];
x_wai._advance = calc.Advance(x_wai._index0, x_wai._index1, x_wai._index2,
x_wai._index3);
}
};
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
}
// Scale the values so they can be used as offsets into buffers.
auto func = PRAGMA_THREADS_FOR {
for (auto x = start; x < stop; ++x) {
(*x_wais)[x]._index0 *= resizer_state.channels;
(*x_wais)[x]._index1 *= resizer_state.channels;
(*x_wais)[x]._index2 *= resizer_state.channels;
(*x_wais)[x]._index3 *= resizer_state.channels;
}
};
samediff::Threads::parallel_for(func, 0, resizer_state.outWidth);
}
// Interpolates vertically across four source rows at one source column.
//   which      - which of xWai's four column offsets to use (0..2; any other
//                value selects _index3)
//   channelNum - channel offset added to the column offset
//   yWai       - supplies the four vertical weights
//   pY0..pY3   - pointers to the starts of the four source rows
//   xWai       - supplies the column offsets
// Returns the weighted sum of the four row samples as float.
template <typename T>
static FORCEINLINE float computeYInterpolation(
        int which, int channelNum, const WeightsAndIndices& yWai,
        const T* pY0, const T* pY1, const T* pY2, const T* pY3,
        const WeightsAndIndices& xWai) {
    // The offsets are Nd4jLong; keep them 64-bit so wide rows are not truncated
    // (the previous `int` local narrowed them).
    Nd4jLong xIndex;
    switch (which) {
        case 0:
            xIndex = xWai._index0;
            break;
        case 1:
            xIndex = xWai._index1;
            break;
        case 2:
            xIndex = xWai._index2;
            break;
        default:
            xIndex = xWai._index3;
            break;
    }
    const Nd4jLong pt_index = xIndex + channelNum;
    return interpolate1D<T>(yWai._weight0, yWai._weight1, yWai._weight2,
                            yWai._weight3, pY0[pt_index], pY1[pt_index],
                            pY2[pt_index], pY3[pt_index]);
}
// Bicubic resize of `image` into `output`, parallelised over the batch dimension.
// For every output row the four contributing source rows are interpolated
// vertically one source column at a time; the four most recent column results
// are cached per channel and reused between neighbouring output pixels as
// directed by WeightsAndIndices::_advance.
//   image            - source; its primary buffer is indexed as
//                      [batch][inHeight][inWidth][channels] with dense strides
//   resizerState     - sizes and scales
//   halfPixelCenters - selects HalfPixelScaler/Keys vs. LegacyScaler
//   output           - destination; its primary buffer is indexed as
//                      [batch][outHeight][outWidth][channels]
template <typename T>
static void
bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) {
    std::vector<WeightsAndIndices> xWais(resizerState.outWidth);
    computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais);

    const auto numChannels = resizerState.channels;
    const Nd4jLong inRowWidth = resizerState.inWidth * numChannels;
    const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
    const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
    T* pOutputY = output->dataBuffer()->primaryAsT<T>();

    auto func = PRAGMA_THREADS_FOR {
        // Scratch cache for the generic-channel path. It is created inside the
        // worker so every thread has its own copy; a single vector shared by all
        // workers was written concurrently by different batches.
        std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);

        for (auto b = start; b < stop; ++b) {
            auto pInput = inputPtr + b * inBatchWidth;

            for (auto y = 0; y < resizerState.outHeight; ++y) {
                auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels];

                WeightsAndIndices yWai;
                if (halfPixelCenters) {
                    getWeightsAndIndices<HalfPixelScaler, true>(
                            resizerState.heightScale, y, resizerState.inHeight, &yWai);
                } else {
                    getWeightsAndIndices<LegacyScaler, false>(
                            resizerState.heightScale, y, resizerState.inHeight, &yWai);
                }

                // Pointers to the four source rows that contribute to this output row.
                const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth;
                const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth;
                const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth;
                const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth;

                if (numChannels == 3) {
                    // Manually unroll case of 3 channels.
                    float cached_value_0[4] = {0};
                    float cached_value_1[4] = {0};
                    float cached_value_2[4] = {0};

                    for (auto x = 0; x < resizerState.outWidth; ++x) {
                        const WeightsAndIndices &xWai = xWais[x];

                        // Shift values in cached_value_* to fill first '_advance' values.
                        switch (xWai._advance) {
                            case 3:
                                cached_value_0[0] = cached_value_0[1];
                                cached_value_0[1] = cached_value_0[2];
                                cached_value_0[2] = cached_value_0[3];
                                cached_value_1[0] = cached_value_1[1];
                                cached_value_1[1] = cached_value_1[2];
                                cached_value_1[2] = cached_value_1[3];
                                cached_value_2[0] = cached_value_2[1];
                                cached_value_2[1] = cached_value_2[2];
                                cached_value_2[2] = cached_value_2[3];
                                break;
                            case 2:
                                cached_value_0[0] = cached_value_0[2];
                                cached_value_0[1] = cached_value_0[3];
                                cached_value_1[0] = cached_value_1[2];
                                cached_value_1[1] = cached_value_1[3];
                                cached_value_2[0] = cached_value_2[2];
                                cached_value_2[1] = cached_value_2[3];
                                break;
                            case 1: {
                                cached_value_0[0] = cached_value_0[3];
                                cached_value_1[0] = cached_value_1[3];
                                cached_value_2[0] = cached_value_2[3];
                                break;
                            }
                        }

                        // Set the remaining '4-_advance' values by computing.
                        // The cases fall through on purpose: starting at slot
                        // '_advance', every later slot is recomputed as well.
                        switch (xWai._advance) {
                            case 0:
                                cached_value_0[0] = computeYInterpolation(
                                        0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_1[0] = computeYInterpolation(
                                        0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_2[0] = computeYInterpolation(
                                        0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                            case 1:
                                cached_value_0[1] = computeYInterpolation(
                                        1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_1[1] = computeYInterpolation(
                                        1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_2[1] = computeYInterpolation(
                                        1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                            case 2:
                                cached_value_0[2] = computeYInterpolation(
                                        2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_1[2] = computeYInterpolation(
                                        2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_2[2] = computeYInterpolation(
                                        2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                            case 3:
                                cached_value_0[3] = computeYInterpolation(
                                        3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_1[3] = computeYInterpolation(
                                        3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                cached_value_2[3] = computeYInterpolation(
                                        3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                break;
                        }

                        pOutput[x * numChannels + 0] =
                                compute(cached_value_0, xWai._weight0, xWai._weight1,
                                        xWai._weight2, xWai._weight3);
                        pOutput[x * numChannels + 1] =
                                compute(cached_value_1, xWai._weight0, xWai._weight1,
                                        xWai._weight2, xWai._weight3);
                        pOutput[x * numChannels + 2] =
                                compute(cached_value_2, xWai._weight0, xWai._weight1,
                                        xWai._weight2, xWai._weight3);
                    }
                } else {
                    for (auto x = 0; x < resizerState.outWidth; ++x) {
                        const WeightsAndIndices &xWai = xWais[x];

                        // Shift values in cachedValue to fill first '_advance' values.
                        switch (xWai._advance) {
                            case 3:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 1];
                                    cachedValue[4 * c + 1] = cachedValue[4 * c + 2];
                                    cachedValue[4 * c + 2] = cachedValue[4 * c + 3];
                                }
                                break;
                            case 2:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 2];
                                    cachedValue[4 * c + 1] = cachedValue[4 * c + 3];
                                }
                                break;
                            case 1: {
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 0] = cachedValue[4 * c + 3];
                                }
                                break;
                            }
                        }

                        // Set the remaining '4-_advance' values by computing
                        // (intentional fall-through, as in the 3-channel path).
                        switch (xWai._advance) {
                            case 0:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 0] = computeYInterpolation(
                                            0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                }
                            case 1:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 1] = computeYInterpolation(
                                            1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                }
                            case 2:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 2] = computeYInterpolation(
                                            2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                }
                            case 3:
                                for (auto c = 0; c < numChannels; ++c) {
                                    cachedValue[4 * c + 3] = computeYInterpolation(
                                            3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
                                }
                                break;
                        }

                        for (auto c = 0; c < numChannels; ++c) {
                            pOutput[x * numChannels + c] =
                                    compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1,
                                            xWai._weight2, xWai._weight3);
                        }
                    }
                }
            }
        }
    };
    samediff::Threads::parallel_tad(func, 0, resizerState.batchSize);
}
// Simplified bicubic resize without antialiasing (typed implementation).
// Validates the requested size through ImageResizerState and, when that
// succeeds, runs bicubicInterpolateWithCaching. Returns the validation status.
// `context` is not used here.
template <typename T>
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
                           bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
    ImageResizerState resizerState(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
    const int status = resizerState.validateAndCreateOutput(image, width, height);
    if (status != Status::OK())
        return status;

    bicubicInterpolateWithCaching<T>(image, resizerState, halfPixelAlign, output);
    return status;
}
// Runtime-typed entry point for the simplified bicubic resize:
// BUILD_SINGLE_SELECTOR picks the resizeBicubicFunctorA_ instantiation matching
// image->dataType() (from NUMERIC_TYPES), forwards every argument unchanged and
// returns its status.
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
}
// ------------------------------------------------------------------------------------------------------------------ //
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
switch (method) { switch (method) {
@ -371,8 +878,21 @@ namespace helpers {
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
// ------------------------------------------------------------------------------------------------------------------ //
// ------------------------------------------------------------------------------------------------------------------ //
// crop and resize helper functor:
// \@param context - launch context for operation
// \@param images - batch of images (4D tensor) with shape {batch, width, height, channels} with given type
// \@param boxes - float boxes for crop
// \@param indices - integer boxes indices for crop
// \@param cropSize - integer size (newWidth, newHeight)
// \@param method - one of bilinear (0) or nearest neighbour (1) interpolation algorithm
// \@param extrapolationVal - radix to increase/decrease image
// \@param crops - output image batch (4D with given type)
//
void void
cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes, NDArray const *indices, NDArray const *cropSize, cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const *images, NDArray const *boxes,
NDArray const *indices, NDArray const *cropSize,
int method, double extrapolationVal, NDArray *crops) { int method, double extrapolationVal, NDArray *crops) {
BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_, BUILD_TRIPLE_SELECTOR(images->dataType(), boxes->dataType(), indices->dataType(), cropAndResizeFunctor_,
(images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES); (images, boxes, indices, cropSize, method, extrapolationVal, crops), NUMERIC_TYPES, FLOAT_TYPES, INTEGER_TYPES);

View File

@ -293,11 +293,614 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images, BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
int width, int height, bool center, NDArray* output), LIBND4J_TYPES); int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool preserveAspectRatio, bool antialias, NDArray* output) { // Bicubic interpolation
return ND4J_STATUS_OK; ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Utility functions and classes
// calculateResizeScale determines the float scaling factor.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
bool alignCorners) {
return (alignCorners && outSize > 1)
? (inSize - 1) / static_cast<float>(outSize - 1)
: inSize / static_cast<float>(outSize);
} }
// Sizes, scales and launch resources shared by the bicubic resize helpers.
struct ImageResizerState {
    explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
        : _alignCorners(alignCorners),
          _halfPixelCenters(halfPixelCenters) {}

    // Reads the input dimensions (dims 0..3 are taken as batch, height, width,
    // channels), stores the requested output size and computes heightScale and
    // widthScale.
    // Returns Status::OK() on success, or an ND4J_STATUS_BAD_INPUT status when a
    // scaled output extent would not fit into an int.
    int validateAndCalculateOutputSize(NDArray const* input, int const width, int const height) {
        batchSize = input->sizeAt(0);
        outHeight = height;
        outWidth = width;
        inHeight = static_cast<int32_t>(input->sizeAt(1));
        inWidth = static_cast<int32_t>(input->sizeAt(2));
        channels = input->sizeAt(3);
        heightScale = calculateResizeScale(inHeight, outHeight, _alignCorners);
        widthScale = calculateResizeScale(inWidth, outWidth, _alignCorners);

        // Guard against overflows
        if (ceilf((outHeight - 1) * heightScale) > static_cast<float>(DataTypeUtils::max<int>())) {
            nd4j_printf("resize_bicubic: Upper overflow occurs for resize height (%f)\n", ceilf((outHeight - 1) * heightScale));
            return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize height");
        }
        // The width check must use widthScale and report the width; it
        // previously re-tested and printed the height values.
        if (ceilf((outWidth - 1) * widthScale) > static_cast<float>(DataTypeUtils::max<int>())) {
            nd4j_printf("resize_bicubic: Upper overflow occurs for resize width (%f)\n", ceilf((outWidth - 1) * widthScale));
            return Status::CODE(ND4J_STATUS_BAD_INPUT, "resize_bicubic: Upper overflow occurs for resize width");
        }

        return Status::OK();
    }

    // Calculates all the required variables. In this revision it only forwards
    // to validateAndCalculateOutputSize(); `output` is not allocated here.
    int validateAndCreateOutput(NDArray const* input, int const width, int const height) {
        return validateAndCalculateOutputSize(input, width, height);
    }

    Nd4jLong batchSize;
    Nd4jLong outHeight;
    Nd4jLong outWidth;
    Nd4jLong inHeight;
    Nd4jLong inWidth;
    Nd4jLong channels;
    float heightScale;   // output row    -> input row scale
    float widthScale;    // output column -> input column scale
    NDArray* output = nullptr;
    // Stream used for kernel launches; the constructor does not set it, so the
    // caller has to assign it before any launch.
    cudaStream_t* stream = nullptr;
private:
    bool _alignCorners;
    bool _halfPixelCenters;
};
// Coordinate mapping that places pixel centers at 0.5, i.e. the top-left pixel
// has floating point coordinates (0.5, 0.5).
struct HalfPixelScaler {
    _CUDA_HD HalfPixelScaler() {}
    inline _CUDA_HD float operator()(const int x, const float scale) const {
        // 0.5 is subtracted from the result because the existing bilinear
        // sampling code etc. assumes pixels are in the old coordinate system.
        const float center = static_cast<float>(x) + 0.5f;
        return center * scale - 0.5f;
    }
};
// Interpolation data for one output coordinate: four cubic weights and the four
// source positions they apply to. Filled by getWeightsAndIndices(); for the
// x-direction the indices are afterwards multiplied by the channel count so
// they can be used as buffer offsets.
struct WeightsAndIndices {
// Weights of the four taps, in the same order as the indices below.
float _weight0;
float _weight1;
float _weight2;
float _weight3;
// Source positions of the four taps, clamped with bound().
Nd4jLong _index0;
Nd4jLong _index1;
Nd4jLong _index2;
Nd4jLong _index3;
// Value returned by CachedInterpolationCalculator::Advance(): how many cached
// values can be reused from the previous output coordinate.
int _advance; // advance value.
};
// Remembers the four source indices of the previous point so that consecutive
// points can reuse values that were already computed. The object is stateful:
// the result of Advance() depends on the arguments of the previous call.
class CachedInterpolationCalculator {
public:
    _CUDA_HD CachedInterpolationCalculator() : _indexes{-1, -1, -1, -1} {}

    // Advances iteration. Returns the number of values that should be copied
    // from the current point to the next point. The copying should always be
    // done by copying the last <retval> values from the old point to the first
    // <retval> values of the new point.
    inline _CUDA_HD int Advance(const Nd4jLong x0, const Nd4jLong x1, const Nd4jLong x2,
                                const Nd4jLong x3) {
        const Nd4jLong wanted[4] = {x0, x1, x2, x3};

        // Two-hand walk: `cached` scans the stored indices while `kept` counts
        // how many of the wanted indices were found so far (kept <= cached).
        int kept = 0;
        for (int cached = 0; cached < 4; ++cached) {
            if (_indexes[cached] != wanted[kept])
                continue;
            if (kept < cached)
                _indexes[kept] = _indexes[cached];
            ++kept;
        }

        // Every slot that could not be reused takes the new index.
        for (int slot = kept; slot < 4; ++slot)
            _indexes[slot] = wanted[slot];

        return kept;
    }

private:
    Nd4jLong _indexes[4];
};
// Fills `table` with (tableSize + 1) pairs of cubic kernel values for parameter
// `a` using a grid-stride loop. For step i, with x = i / tableSize:
//   table[2 * i]     - kernel value at distance x
//   table[2 * i + 1] - kernel value at distance x + 1
static __global__ void initCoefTableKernel(const double a, float* table, Nd4jLong tableSize) {
    auto first = blockIdx.x * blockDim.x + threadIdx.x;
    auto stride = blockDim.x * gridDim.x;

    for (int i = first; i <= tableSize; i += stride) {
        float* entry = table + i * 2;
        float x = i * 1.0 / tableSize;
        entry[0] = ((a + 2) * x - (a + 3)) * x * x + 1;
        x += 1.0;
        entry[1] = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
    }
}
static const Nd4jLong kTableSize = (1 << 10);
float* initCoeffsTable(const double a, cudaStream_t* stream) {
// Allocate and initialize coefficients table using Bicubic
// convolution algorithm.
// https://en.wikipedia.org/wiki/Bicubic_interpolation
float* coeffs_table; // = new float[(kTableSize + 1) * 2];
auto err = cudaMalloc(&coeffs_table, sizeof(float) * ((kTableSize + 1) * 2));
if (err != 0) {
throw cuda_exception::build("helpers::initCoeffsTable: Cannot allocate memory for vertical parts rectangulars", err);
}
initCoefTableKernel<<<128,128,128, *stream>>>(a, coeffs_table, kTableSize);
err = cudaStreamSynchronize(*stream);
if (err != 0) {
throw cuda_exception::build("helpers::initCoeffsTable: Cannot syncronize kernel", err);
}
return coeffs_table;
}
// _CUDA_HD const float* getCoeffsTable(const bool use_keys_cubic) {
// // Static so that we initialize it on first use
// if (use_keys_cubic) {
// // http://ieeexplore.ieee.org/document/1163711/
// // R. G. Keys. Cubic convolution interpolation for digital image
// // processing. IEEE Transactions on Acoustics, Speech, and Signal
//     // Processing, 29(6):1153-1160, 1981.
// //static const float* coeffs_table = initCoeffsTable(-0.5f, stream);
// return sCoeffsTableHalf;
// } else {
// //static const float* coeffs_table = initCoeffsTable(-0.75f, stream);
// return sCoeffsTableThreeFourth;
// }
// }
// Clamps val into [0, limit - 1].
inline _CUDA_HD Nd4jLong bound(Nd4jLong val, Nd4jLong limit) {
    const Nd4jLong nonNegative = math::nd4j_max(Nd4jLong{0}, val);
    return math::nd4j_min(limit - 1ll, nonNegative);
}
// Weighted sum of four samples: each value is converted to float, multiplied by
// its weight, and the products are accumulated in order 0, 1, 2, 3.
template <typename T>
inline _CUDA_HD float interpolate1D(const float weight0, const float weight1, const float weight2, const float weight3,
                                    const T value0, const T value1, const T value2, const T value3) {
    float acc = static_cast<float>(value0) * weight0;
    acc = acc + static_cast<float>(value1) * weight1;
    acc = acc + static_cast<float>(value2) * weight2;
    acc = acc + static_cast<float>(value3) * weight3;
    return acc;
}
// Combines four cached column values with the four x-weights of the current
// output pixel (delegates to interpolate1D).
static _CUDA_HD float compute(float values[4], const float xW0, const float xW1, const float xW2, const float xW3) {
    const float v0 = values[0];
    const float v1 = values[1];
    const float v2 = values[2];
    const float v3 = values[3];
    return interpolate1D(xW0, xW1, xW2, xW3, v0, v1, v2, v3);
}
// Fills *out with the four source indices (clamped to [0, limit - 1]) and the
// four cubic weights that contribute to output coordinate out_loc.
//   Scaler         - functor mapping (out_loc, scale) to a source coordinate
//   use_keys_cubic - selects the border policy (see below)
//   coeffs_table   - coefficient table holding (kTableSize + 1) weight pairs
template <typename Scaler, bool use_keys_cubic>
inline _CUDA_HD void getWeightsAndIndices(float const* coeffs_table, const float scale, const Nd4jLong out_loc, const Nd4jLong limit, WeightsAndIndices* out) {
    const Scaler scaler;
    const float srcPos = scaler(out_loc, scale);
    const Nd4jLong srcBase = math::nd4j_floor<float, Nd4jLong>(srcPos);
    const float frac = srcPos - srcBase;
    // Table slot of the fractional part, and of its mirror (1 - frac).
    const Nd4jLong tablePos = math::nd4j_round<float, Nd4jLong>(frac * kTableSize);
    const Nd4jLong mirrorPos = kTableSize - tablePos;

    // The four taps are clamped the same way for both kernel flavours.
    out->_index0 = bound(srcBase - 1, limit);
    out->_index1 = bound(srcBase, limit);
    out->_index2 = bound(srcBase + 1, limit);
    out->_index3 = bound(srcBase + 2, limit);

    if (!use_keys_cubic) {
        // Legacy behaviour: a clamped tap keeps its weight.
        out->_weight0 = coeffs_table[tablePos * 2 + 1];
        out->_weight1 = coeffs_table[tablePos * 2];
        out->_weight2 = coeffs_table[mirrorPos * 2];
        out->_weight3 = coeffs_table[mirrorPos * 2 + 1];
        return;
    }

    // The legacy code placed more weight on the edge pixels, because clamping
    // could repeat an edge pixel. Here a tap whose index had to be clamped gets
    // weight 0 and the remaining weights are renormalized to sum to 1.0, as the
    // scale_and_translate_op does.
    out->_weight0 = (out->_index0 == srcBase - 1) ? coeffs_table[tablePos * 2 + 1] : 0.0f;
    out->_weight1 = (out->_index1 == srcBase) ? coeffs_table[tablePos * 2] : 0.0f;
    out->_weight2 = (out->_index2 == srcBase + 1) ? coeffs_table[mirrorPos * 2] : 0.0f;
    out->_weight3 = (out->_index3 == srcBase + 2) ? coeffs_table[mirrorPos * 2 + 1] : 0.0f;

    const float total = out->_weight0 + out->_weight1 + out->_weight2 + out->_weight3;
    // Skip the renormalization when the sum is too close to zero to divide by.
    if (math::nd4j_abs(total) >= 1000.0f * DataTypeUtils::min<float>()) {
        const float invTotal = 1.0f / total;
        out->_weight0 *= invTotal;
        out->_weight1 *= invTotal;
        out->_weight2 *= invTotal;
        out->_weight3 *= invTotal;
    }
}
// Legacy coordinate mapping: source = out * scale, with no half-pixel offset.
// This older method is incorrect: it shifts every resize slightly, so results
// are inconsistent (e.g. flip-then-resize differs from resize-then-flip).
struct LegacyScaler {
    _CUDA_HD LegacyScaler() {}
    inline _CUDA_HD float operator()(const int x, const float scale) const {
        const float outCoord = static_cast<float>(x);
        return outCoord * scale;
    }
};
// Multiplies the four indices of each of the outWidth entries in pXWais by
// `channels`, turning column indices into element offsets within a row.
// Entries are distributed over the launched threads with a grid-stride loop.
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
    auto first = blockIdx.x * blockDim.x + threadIdx.x;
    auto stride = blockDim.x * gridDim.x;

    for (auto x = first; x < outWidth; x += stride) {
        WeightsAndIndices& wai = pXWais[x];
        wai._index0 *= channels;
        wai._index1 *= channels;
        wai._index2 *= channels;
        wai._index3 *= channels;
    }
}
// For every output column x < outWidth: computes weights and clamped source
// indices into pXWais[x] and stores the result of calc->Advance() for that
// column in pXWais[x]._advance. `channels` is not used by this kernel.
// NOTE(review): `calc` is one object that every thread running this loop calls
// Advance() on, and Advance() mutates it; the stored value therefore depends on
// the order in which columns are processed.
static __global__ void advaceWeightsAndIndicesKernel(float const* cacheTable, CachedInterpolationCalculator* calc, WeightsAndIndices* pXWais, Nd4jLong inWidth, float widthScale,
                                                     Nd4jLong outWidth, Nd4jLong channels, bool halfPixelCenters) {
    auto first = blockIdx.x * blockDim.x + threadIdx.x;
    auto stride = blockDim.x * gridDim.x;

    for (auto x = first; x < outWidth; x += stride) {
        WeightsAndIndices* wai = &pXWais[x];
        if (halfPixelCenters) {
            getWeightsAndIndices<HalfPixelScaler, true>(cacheTable, widthScale, x, inWidth, wai);
        } else {
            getWeightsAndIndices<LegacyScaler, false>(cacheTable, widthScale, x, inWidth, wai);
        }
        wai->_advance = calc->Advance(wai->_index0, wai->_index1, wai->_index2, wai->_index3);
    }
}
// Fills pXWais[0 .. outWidth) with the x-direction weights, source offsets and
// `_advance` counts for the resize described by resizerState.
//   coeffsTable      - coefficient table passed on to the kernel
//   resizerState     - sizes, scales and the stream to launch on
//   halfPixelCenters - selects HalfPixelScaler/Keys vs. LegacyScaler
//   pXWais           - array of outWidth entries, written by the kernels
// Throws cuda_exception when a CUDA call fails.
static void computeXWeightsAndIndices(float const* coeffsTable, const ImageResizerState& resizerState,
                                      const bool halfPixelCenters,
                                      WeightsAndIndices* pXWais) {
    auto stream = resizerState.stream;
    auto outWidth = resizerState.outWidth;

    // The calculator is built on the host and copied to the device.
    CachedInterpolationCalculator calc;
    CachedInterpolationCalculator* pCalcD = nullptr;
    auto err = cudaMalloc(&pCalcD, sizeof(CachedInterpolationCalculator));
    if (err != 0) {
        throw cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot allocated device memory for interpolate calculator", err);
    }
    err = cudaMemcpy(pCalcD, &calc, sizeof(CachedInterpolationCalculator), cudaMemcpyHostToDevice);
    if (err != 0) {
        cudaFree(pCalcD);
        throw cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot set up device memory for interpolate calculator", err);
    }

    // Advance() compares each column with the state left by the previous call
    // and mutates that state, so the columns must be visited in order by a
    // single thread. A <<<1, 1>>> launch makes the kernel's strided loop walk
    // x = 0, 1, 2, ... sequentially; the former 128x128 launch had every thread
    // updating the same calculator concurrently.
    advaceWeightsAndIndicesKernel<<<1, 1, 128, *stream>>>(coeffsTable, pCalcD, pXWais, resizerState.inWidth, resizerState.widthScale, outWidth, resizerState.channels, halfPixelCenters);

    // Wait for the kernel before releasing the calculator it is using.
    err = cudaStreamSynchronize(*stream);
    auto freeErr = cudaFree(pCalcD);
    if (err != 0) {
        throw cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after advance weights and indicers", err);
    }
    if (freeErr != 0) {
        throw cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot deallocated device memory for interpolate calculator", freeErr);
    }

    // Scale the values so they can be used as offsets into buffers.
    accumulateChannelsKernel<<<128, 128, 512, *stream>>>(pXWais, outWidth, resizerState.channels);
    err = cudaStreamSynchronize(*stream);
    if (err != 0) {
        throw cuda_exception::build("helpers::computeXWeightsAndIndices: Cannot synchronize stream after accumulate channels", err);
    }
}
// Interpolates vertically across four source rows at one source column.
//   which      - which of xWai's four column offsets to use (0..2; any other
//                value selects _index3)
//   channelNum - channel offset added to the column offset
//   yWai       - supplies the four vertical weights
//   pY0..pY3   - pointers to the starts of the four source rows
//   xWai       - supplies the column offsets
// Returns the weighted sum of the four row samples as float.
template <typename T>
static _CUDA_HD FORCEINLINE float computeYInterpolation(
        int which, int channelNum, const WeightsAndIndices& yWai,
        const T* pY0, const T* pY1, const T* pY2, const T* pY3,
        const WeightsAndIndices& xWai) {
    // The offsets are Nd4jLong; keep them 64-bit so wide rows are not truncated
    // (the previous `int` local narrowed them).
    Nd4jLong xIndex;
    switch (which) {
        case 0:
            xIndex = xWai._index0;
            break;
        case 1:
            xIndex = xWai._index1;
            break;
        case 2:
            xIndex = xWai._index2;
            break;
        default:
            xIndex = xWai._index3;
            break;
    }
    const Nd4jLong pt_index = xIndex + channelNum;
    return interpolate1D<T>(yWai._weight0, yWai._weight1, yWai._weight2,
                            yWai._weight3, pY0[pt_index], pY1[pt_index],
                            pY2[pt_index], pY3[pt_index]);
}
template <typename T>
static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, T* outputPtr) {
// auto numChannels = pResizerState->channels;
for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) {
auto pInput = inputPtr + b * inBatchWidth;
for (Nd4jLong y = threadIdx.x; y < pResizerState->outHeight; y += blockDim.x) {
auto pos = (b * pResizerState->outHeight + y) * pResizerState->outWidth * pResizerState->channels;
auto pOutput = &outputPtr[pos];
struct WeightsAndIndices yWai;
if (halfPixelCenters) {
getWeightsAndIndices<HalfPixelScaler, true>(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai);
} else {
getWeightsAndIndices<LegacyScaler, false>(cachedTable, pResizerState->heightScale, y, pResizerState->inHeight, &yWai);
}
// Make pointers represent offsets of data in inputBPtr.
const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth;
const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth;
const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth;
const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth;
if (pResizerState->channels == 3) {
// Manually unroll case of 3 channels.
float cached_value_0[4] = {0};
float cached_value_1[4] = {0};
float cached_value_2[4] = {0};
for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) {
const WeightsAndIndices& xWai = xWais[x];
// Shift values in cached_value_* to fill first '_advance' values.
switch (xWai._advance) {
case 3:
cached_value_0[0] = cached_value_0[1];
cached_value_0[1] = cached_value_0[2];
cached_value_0[2] = cached_value_0[3];
cached_value_1[0] = cached_value_1[1];
cached_value_1[1] = cached_value_1[2];
cached_value_1[2] = cached_value_1[3];
cached_value_2[0] = cached_value_2[1];
cached_value_2[1] = cached_value_2[2];
cached_value_2[2] = cached_value_2[3];
break;
case 2:
cached_value_0[0] = cached_value_0[2];
cached_value_0[1] = cached_value_0[3];
cached_value_1[0] = cached_value_1[2];
cached_value_1[1] = cached_value_1[3];
cached_value_2[0] = cached_value_2[2];
cached_value_2[1] = cached_value_2[3];
break;
case 1: {
cached_value_0[0] = cached_value_0[3];
cached_value_1[0] = cached_value_1[3];
cached_value_2[0] = cached_value_2[3];
break;
}
}
// Set the remaining '4-_advance' values by computing.
switch (xWai._advance) {
case 0:
cached_value_0[0] = computeYInterpolation(0, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_1[0] = computeYInterpolation(0, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_2[0] = computeYInterpolation(0, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
case 1:
cached_value_0[1] = computeYInterpolation(1, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_1[1] = computeYInterpolation(1, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_2[1] = computeYInterpolation(1, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
case 2:
cached_value_0[2] = computeYInterpolation(2, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_1[2] = computeYInterpolation(2, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_2[2] = computeYInterpolation(2, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
case 3:
cached_value_0[3] = computeYInterpolation(3, 0, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_1[3] = computeYInterpolation(3, 1, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
cached_value_2[3] = computeYInterpolation(3, 2, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
// break;
}
pOutput[x * pResizerState->channels + 0] = compute(cached_value_0, xWai._weight0, xWai._weight1,
xWai._weight2, xWai._weight3);
pOutput[x * pResizerState->channels + 1] = compute(cached_value_1, xWai._weight0, xWai._weight1,
xWai._weight2, xWai._weight3);
pOutput[x * pResizerState->channels + 2] = compute(cached_value_2, xWai._weight0, xWai._weight1,
xWai._weight2, xWai._weight3);
}
} else {
for (Nd4jLong x = 0; x < pResizerState->outWidth; ++x) {
const WeightsAndIndices& xWai = xWais[x];
// Shift values in cachedValue to fill first '_advance' values.
switch (xWai._advance) {
case 3:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 0] = cachedValue[4 * c + 1];
cachedValue[4 * c + 1] = cachedValue[4 * c + 2];
cachedValue[4 * c + 2] = cachedValue[4 * c + 3];
}
break;
case 2:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 0] = cachedValue[4 * c + 2];
cachedValue[4 * c + 1] = cachedValue[4 * c + 3];
}
break;
case 1: {
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 0] = cachedValue[4 * c + 3];
}
break;
}
}
// Set the remaining '4-_advance' values by computing.
switch (xWai._advance) {
case 0:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 0] = computeYInterpolation(0, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
}
case 1:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 1] = computeYInterpolation(1, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
}
case 2:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 2] = computeYInterpolation(2, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
}
case 3:
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
cachedValue[4 * c + 3] = computeYInterpolation(3, c, yWai, y_ptr_0, y_ptr_1, y_ptr_2, y_ptr_3, xWai);
}
// break;
}
for (Nd4jLong c = 0; c < pResizerState->channels; ++c) {
pOutput[x * pResizerState->channels + c] = compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3);
}
}
}
}
}
}
// Host-side driver of the cached bicubic resize.
// Stages the device-side helper buffers, launches bicubicInterpolateWithCachingKernel<T> on
// resizerState.stream, waits for it and releases the helper buffers.
//
// image            - source array; its special buffer is handed to the kernel as input
// resizerState     - host-side resize description; a byte copy of it is uploaded for the kernel
// halfPixelCenters - selects the coefficient table (initCoeffsTable(-0.5, ...) when true,
//                    initCoeffsTable(-0.75, ...) otherwise) and is forwarded to the weights
//                    computation and to the kernel
// output           - destination array; its special buffer is handed to the kernel as output
//
// Throws whatever cuda_exception::build produces when a CUDA runtime call reports an error.
// Every helper buffer that is still allocated at that point is released before the exception
// leaves this function.
template <typename T>
static void
bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) {
    const auto numChannels = resizerState.channels;
    const Nd4jLong inRowWidth = resizerState.inWidth * numChannels;
    const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
    auto stream = resizerState.stream;

    // All helper buffers are declared up front and start as null so that the catch block below
    // can release exactly those that exist (cudaFree on a null pointer is a no-op).
    ImageResizerState* resizerStateD = nullptr;
    float* cachedValue = nullptr;
    WeightsAndIndices* xWais = nullptr;
    decltype(initCoeffsTable(-0.5, stream)) coeffsTable = nullptr;

    try {
        auto err = cudaMalloc(&resizerStateD, sizeof(ImageResizerState));
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for resizerState", err);
        }
        err = cudaMemcpy(resizerStateD, &resizerState, sizeof(ImageResizerState), cudaMemcpyHostToDevice);
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot set up memory for resizerState", err);
        }

        // Scratch area of 4 floats per channel; not allocated at all for the 3-channel case.
        // NOTE(review): the same scratch pointer is passed to every block of the launch below -
        // confirm that the kernel does not let blocks overwrite each other's cached values.
        const size_t cachedSize = sizeof(float) * (numChannels == 3 ? 0 : 4 * numChannels);
        if (cachedSize) {
            err = cudaMalloc(reinterpret_cast<void**>(&cachedValue), cachedSize);
            if (err != 0) {
                throw cuda_exception::build(
                        "helpers::bicubicInterpolateWithCaching: Cannot allocate memory for cached values", err);
            }
            err = cudaMemset(cachedValue, 0, cachedSize);
            if (err != 0) {
                throw cuda_exception::build(
                        "helpers::bicubicInterpolateWithCaching: Cannot set up memory for cached values", err);
            }
        }

        // one WeightsAndIndices entry per output column
        err = cudaMalloc(&xWais, sizeof(WeightsAndIndices) * resizerState.outWidth);
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot allocate memory for weights and indices", err);
        }

        coeffsTable = halfPixelCenters ? initCoeffsTable(-0.5, stream) : initCoeffsTable(-0.75, stream);

        computeXWeightsAndIndices(coeffsTable, resizerState, halfPixelCenters, xWais);
        // Wait for the stream instead of polling it: cudaStreamQuery reports cudaErrorNotReady
        // while work is still pending, which would be mistaken for a failure here.
        err = cudaStreamSynchronize(*stream);
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err);
        }

        const T* pInput = image->getDataBuffer()->specialAsT<T>();
        T* pOutput = output->dataBuffer()->specialAsT<T>();
        bicubicInterpolateWithCachingKernel<T><<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput,
                resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput);
        err = cudaStreamSynchronize(*stream);
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Kernels finished with error", err);
        }

        // Regular release path. Each pointer is nulled right after its cudaFree attempt so that the
        // catch block never passes the same pointer to cudaFree a second time.
        err = cudaFree(resizerStateD);
        resizerStateD = nullptr;
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for resizerState", err);
        }
        if (cachedSize) {
            err = cudaFree(cachedValue);
            cachedValue = nullptr;
            if (err != 0) {
                throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for cached values", err);
            }
        }
        err = cudaFree(xWais);
        xWais = nullptr;
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for weights and indices", err);
        }
        err = cudaFree(coeffsTable);
        coeffsTable = nullptr;
        if (err != 0) {
            throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: Cannot deallocate memory for coefficients table", err);
        }
    } catch (...) {
        // Best-effort cleanup: results are ignored on purpose so the original error is the one reported.
        static_cast<void>(cudaFree(resizerStateD));
        static_cast<void>(cudaFree(cachedValue));
        static_cast<void>(cudaFree(xWais));
        static_cast<void>(cudaFree(coeffsTable));
        throw;
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Typed implementation behind resizeBicubicFunctor.
// NOTE: currently a stub - none of the arguments is used, nothing is written to `output`,
// and Status::OK() is returned unconditionally.
template <typename T>
int resizeBicubicFunctor_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output) {
return Status::OK();
}
// Untyped entry point for the bicubic resize helper.
// Forwards all arguments unchanged to resizeBicubicFunctor_ through BUILD_SINGLE_SELECTOR, keyed on
// image->dataType() over NUMERIC_TYPES, and returns that call's result.
// NOTE(review): what happens for a data type outside NUMERIC_TYPES depends on the macro, which is
// defined elsewhere - confirm before relying on it.
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctor_, (context, image,
width, height, preserveAspectRatio, antialias, output), NUMERIC_TYPES);
}
// Explicit instantiation request for resizeBicubicFunctor_ over NUMERIC_TYPES
// (presumably one instantiation per listed type - the macro is defined elsewhere).
BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctor_, (nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output), NUMERIC_TYPES);
// ------------------------------------------------------------------------------------------------------------------ //
// ------------------------------------------------------------------------------------------------------------------ //
// simplified bicubic resize without antialiasing
//
// Typed bicubic resize without antialiasing.
// Builds an ImageResizerState from the two sampling flags, binds it to the CUDA stream of `context`,
// validates the requested size against `image` and - only when validation succeeds - runs
// bicubicInterpolateWithCaching<T>. Returns the validation status in every case.
template <typename T>
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
                           bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
    // constructor arguments: align_corners, half_pixel_align
    ImageResizerState resizerState(alignCorners, halfPixelCenters);
    resizerState.stream = context->getCudaStream();

    // the prepare/register pair brackets the device work: output is written, image is read
    NDArray::prepareSpecialUse({output}, {image});

    int status = resizerState.validateAndCreateOutput(image, width, height);
    if (status == Status::OK()) {
        bicubicInterpolateWithCaching<T>(image, resizerState, halfPixelCenters, output);
    }

    NDArray::registerSpecialUse({output}, {image});

    return status;
}
// Untyped entry point for the simplified (no antialiasing) bicubic resize.
// Forwards all arguments unchanged to resizeBicubicFunctorA_ through BUILD_SINGLE_SELECTOR, keyed on
// image->dataType() over NUMERIC_TYPES, and returns that call's result.
// NOTE(review): what happens for a data type outside NUMERIC_TYPES depends on the macro, which is
// defined elsewhere - confirm before relying on it.
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
image, width, height, alignCorners, halfPixelCenters, output), NUMERIC_TYPES);
}
// Explicit instantiation request for resizeBicubicFunctorA_ over NUMERIC_TYPES
// (presumably one instantiation per listed type - the macro is defined elsewhere).
BUILD_SINGLE_TEMPLATE(template int resizeBicubicFunctorA_, (nd4j::LaunchContext * context,
NDArray const* image, int width, int height, bool const alignCorners, bool const halfPixelCenters, NDArray* output), NUMERIC_TYPES);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
switch (method) { switch (method) {
@ -316,13 +919,13 @@ namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// --------------------------------------------------------------------------------------------------------------- // // --------------------------------------------------------------------------------------------------------------- //
// Crop and Resize helper implementation // Crop and Resize helper implementation
// --------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
// cropAndResize kernel // cropAndResize kernel type of input(images) and output should be the same
// //
template <typename T, typename Z, typename I> template <typename T, typename Z, typename I>
static __global__ void cropAndResizeKernel(T const *images, Nd4jLong* imagesShape, Z const* boxes, Nd4jLong* boxesShape, static __global__ void cropAndResizeKernel(T const *images, Nd4jLong* imagesShape, Z const* boxes, Nd4jLong* boxesShape,
I const* indices, Nd4jLong* indexShape, I const* cropSize, Nd4jLong* cropShape, int method, I const* indices, Nd4jLong* indexShape, I const* cropSize, Nd4jLong* cropShape, int method,
double extrapolationVal, Z* output, Nd4jLong* outputShape, int numBoxes, int cropHeight, int cropWidth, double extrapolationVal, T* output, Nd4jLong* outputShape, int numBoxes, int cropHeight, int cropWidth,
int batchSize, int imageHeight, int imageWidth, int depth) { int batchSize, int imageHeight, int imageWidth, int depth) {
for (int b = blockIdx.x; b < numBoxes; b += gridDim.x) for (int b = blockIdx.x; b < numBoxes; b += gridDim.x)
@ -464,7 +1067,7 @@ namespace helpers {
Z const* boxesBuf = reinterpret_cast<Z const*>(boxes->getSpecialBuffer()); Z const* boxesBuf = reinterpret_cast<Z const*>(boxes->getSpecialBuffer());
I const* indexBuf = reinterpret_cast<I const*>(indices->getSpecialBuffer()); I const* indexBuf = reinterpret_cast<I const*>(indices->getSpecialBuffer());
I const* cropSizes = reinterpret_cast<I const*>(cropSize->getSpecialBuffer()); I const* cropSizes = reinterpret_cast<I const*>(cropSize->getSpecialBuffer());
Z* outBuf = reinterpret_cast<Z*>(crops->specialBuffer()); T* outBuf = reinterpret_cast<T*>(crops->specialBuffer());
NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize}); NDArray::prepareSpecialUse({crops}, {images, boxes, indices, cropSize});
cropAndResizeKernel<T,Z,I><<<batchSize, math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth), 512, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), indexBuf, indices->getSpecialShapeInfo(), cropAndResizeKernel<T,Z,I><<<batchSize, math::nd4j_max(imageHeight * imageWidth, cropHeight * cropWidth), 512, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), indexBuf, indices->getSpecialShapeInfo(),

View File

@ -215,6 +215,10 @@ namespace nd4j {
} else { } else {
addInfVectorKernel<T><<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); addInfVectorKernel<T><<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength);
} }
err = cudaStreamSynchronize(*stream);
if (0 != err) {
throw cuda_exception::build("helpers::skipgram_: Cannot synchronize stream after addInfVectorKernel", err);
}
err = cudaFree(neu1e); err = cudaFree(neu1e);
if (0 != err) { if (0 != err) {
@ -317,10 +321,15 @@ namespace nd4j {
} }
} }
addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); addInfVectorKernel<T><<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength);
err = cudaStreamSynchronize(*stream);
if (0 != err) {
throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot synchronize stream after addInfVectorKernel", err);
}
// optionally release temp arrays // optionally release temp arrays
err = cudaFree(neu1e); err = cudaFree(neu1e);
if (err != 0) { if (err != 0) {
throw cuda_exception::build("helpers::skipgramBatchExec_: Cannot deallocate memory with stage", err);
break; break;
} }
// if (vectorLength > 600) // if (vectorLength > 600)
@ -485,9 +494,22 @@ namespace nd4j {
infVector[i] += neu1e[i]; infVector[i] += neu1e[i];
} }
} }
err = cudaStreamSynchronize(*stream);
if (0 != err) {
throw cuda_exception::build(
"helpers::cbow_: Cannot synchronize stream after kernel executing", err);
}
err = cudaFree(neu1); err = cudaFree(neu1);
if (0 != err) {
throw cuda_exception::build(
"helpers::cbow_: Cannot deallocate memory for synonims table", err);
}
err = cudaFree(neu1e); err = cudaFree(neu1e);
if (0 != err) {
throw cuda_exception::build(
"helpers::cbow_: Cannot deallocate memory for antonims table", err);
}
} }
BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES);
@ -574,13 +596,13 @@ namespace nd4j {
const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1);
const auto numTargets = context.sizeAt(0); const auto numTargets = context.sizeAt(0);
const int contextWidth = context.sizeAt(1); const int contextWidth = context.sizeAt(1);
const auto bContext = reinterpret_cast<int*>(context.buffer()); //bufferAsT<int>(); //const auto bContext = reinterpret_cast<int*>(context.buffer()); //bufferAsT<int>();
const auto dContext = reinterpret_cast<int*>(context.specialBuffer()); //bufferAsT<int>(); const auto dContext = context.dataBuffer()->specialAsT<int>(); //bufferAsT<int>();
const auto bLocker = reinterpret_cast<int*>(lockedWords.buffer()); //lockedWords.bufferAsT<int>(); // const auto bLocker = reinterpret_cast<int*>(lockedWords.buffer()); //lockedWords.bufferAsT<int>();
const auto dLocker = reinterpret_cast<int*>(lockedWords.specialBuffer()); //lockedWords.bufferAsT<int>(); const auto dLocker = lockedWords.dataBuffer()->specialAsT<int>(); //.specialBuffer()); //lockedWords.bufferAsT<int>();
const auto bIndices = reinterpret_cast<int*>(indices.buffer());//AsT<int>(); const auto bIndices = indices.dataBuffer()->primaryAsT<int>(); //buffer());//AsT<int>();
const auto bCodes = reinterpret_cast<int8_t*>(codes.buffer()); //bufferAsT<int8_t>(); const auto bCodes = codes.dataBuffer()->primaryAsT<int8_t>(); //reinterpret_cast<int8_t*>(codes.buffer()); //bufferAsT<int8_t>();
const auto bStarters = reinterpret_cast<int*>(negStarters.buffer()); //AsT<int>(); const auto bStarters = negStarters.dataBuffer()->primaryAsT<int>(); //reinterpret_cast<int*>(negStarters.buffer()); //AsT<int>();
const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1); const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1);
lr.syncToHost(); lr.syncToHost();
nLabels.syncToHost(); nLabels.syncToHost();
@ -678,6 +700,11 @@ namespace nd4j {
// } // }
} }
cerr = cudaStreamSynchronize(*stream);
if (cerr) {
throw cuda_exception::build("Cannot syncronize stream before memory deallocation", cerr);
}
cerr = cudaFree(neu1); cerr = cudaFree(neu1);
if (cerr) { if (cerr) {
throw cuda_exception::build("Cannot deallocate temp buffer1", cerr); throw cuda_exception::build("Cannot deallocate temp buffer1", cerr);

View File

@ -43,6 +43,8 @@ namespace helpers {
NDArray* output); NDArray* output);
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool preserveAspectRatio, bool antialias, NDArray* output); bool preserveAspectRatio, bool antialias, NDArray* output);
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);

View File

@ -50,7 +50,7 @@ namespace helpers {
Nd4jLong space_shape[NUM_BLOCK_DIMS]; Nd4jLong space_shape[NUM_BLOCK_DIMS];
Nd4jLong batch_shape[NUM_BLOCK_DIMS]; Nd4jLong batch_shape[NUM_BLOCK_DIMS];
const int batch_size = batch->sizeAt(0); const int batchSize = batch->sizeAt(0);
const int space_size = space->sizeAt(0); const int space_size = space->sizeAt(0);
#pragma unroll #pragma unroll
@ -65,7 +65,7 @@ namespace helpers {
auto batch_strides = batch->stridesOf(); auto batch_strides = batch->stridesOf();
// TODO: this loop should be moved to _execute phase // TODO: this loop should be moved to _execute phase
for (int batch_b = 0; batch_b < batch_size; ++batch_b) { for (int batch_b = 0; batch_b < batchSize; ++batch_b) {
const Nd4jLong space_b = batch_b % space_size; const Nd4jLong space_b = batch_b % space_size;
Nd4jLong block_index = batch_b / space_size; Nd4jLong block_index = batch_b / space_size;
Nd4jLong block_offsets[NUM_BLOCK_DIMS]; Nd4jLong block_offsets[NUM_BLOCK_DIMS];

View File

@ -1528,6 +1528,111 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
delete results; delete results;
} }
// Smoke test for resize_bilinear on a {2, 5, 5, 3} float input with integer arguments {9, 9}.
// Only the op status is asserted; the result buffer and its shape are printed, and the comparisons
// against `expected` are commented out, so `expected` is currently unused.
// NOTE(review): `expected` is declared with shape {10, 10, 4}, which does not look like it can match
// a resize of this input - confirm the intended reference values before enabling the asserts.
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f,
0.9234f, 0.0856f, 0.7938f,
0.6591f, 0.5555f, 0.1596f,
0.3087f, 0.1548f, 0.4695f,
0.9939f, 0.6113f, 0.6765f,
0.1800f, 0.6750f, 0.2246f,
0.0509f, 0.4601f, 0.8284f,
0.2354f, 0.9752f, 0.8361f,
0.2585f, 0.4189f, 0.7028f,
0.7679f, 0.5373f, 0.7234f,
0.2690f, 0.0062f, 0.0327f,
0.0644f, 0.8428f, 0.7494f,
0.0755f, 0.6245f, 0.3491f,
0.5793f, 0.5730f, 0.1822f,
0.6420f, 0.9143f, 0.3019f,
0.3574f, 0.1704f, 0.8395f,
0.5468f, 0.0744f, 0.9011f,
0.6574f, 0.4124f, 0.2445f,
0.4248f, 0.5219f, 0.6952f,
0.4900f, 0.2158f, 0.9549f,
0.1386f, 0.1544f, 0.5365f,
0.0134f, 0.4163f, 0.1456f,
0.4109f, 0.2484f, 0.3330f,
0.2974f, 0.6636f, 0.3808f,
0.8664f, 0.1896f, 0.7530f,
0.7215f, 0.6612f, 0.7270f,
0.5704f, 0.2666f, 0.7453f,
0.0444f, 0.3024f, 0.4850f,
0.7982f, 0.0965f, 0.7843f,
0.5075f, 0.0844f, 0.8370f,
0.6103f, 0.4604f, 0.6087f,
0.8594f, 0.4599f, 0.6714f,
0.2744f, 0.1981f, 0.4143f,
0.7821f, 0.3505f, 0.5040f,
0.1180f, 0.8307f, 0.1817f,
0.8442f, 0.5074f, 0.4471f,
0.5105f, 0.6666f, 0.2576f,
0.2341f, 0.6801f, 0.2652f,
0.5394f, 0.4690f, 0.6146f,
0.1210f, 0.2576f, 0.0769f,
0.4643f, 0.1628f, 0.2026f,
0.3774f, 0.0506f, 0.3462f,
0.5720f, 0.0838f, 0.4228f,
0.0588f, 0.5362f, 0.4756f,
0.2530f, 0.1778f, 0.0751f,
0.8977f, 0.3648f, 0.3065f,
0.4739f, 0.7014f, 0.4473f,
0.5171f, 0.1744f, 0.3487f});
NDArray expected = NDArrayFactory::create<float>('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2,
9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4,
11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8,
7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4,
10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16.,
13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8,
8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6,
11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2,
16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6,
18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24.,
21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,
19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24.,
21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.});
//input.linspace(1);
nd4j::ops::resize_bilinear op;
// integer arguments {9, 9}: the requested output size (labelled "9x9" in the print below)
auto results = op.execute({&input}, {}, {9, 9});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// diagnostic output only - nothing below checks the values
result->printIndexedBuffer("Resized to 9x9");
//expected.printIndexedBuffer("Expect for 10x10");
result->printShapeInfo("Output shape");
// expected.printShapeInfo("Expect shape");
// ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
@ -2145,7 +2250,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
NDArray cropSize = NDArrayFactory::create<int>({1, 1}); NDArray cropSize = NDArrayFactory::create<int>({1, 1});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {1,1,1,1}, {2.5f}); NDArray expected = NDArrayFactory::create<double>('c', {1,1,1,1}, {2.5f});
nd4j::ops::crop_and_resize op; nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {}); auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {});
@ -2154,6 +2259,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
auto result = results->at(0); auto result = results->at(0);
// result->printIndexedBuffer("Cropped and Resized"); // result->printIndexedBuffer("Cropped and Resized");
ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
@ -2187,7 +2293,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
NDArray images ('c', {1,2,2,1}, {1,2,3,4}); NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64); NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64);
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3}); NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
@ -2211,7 +2317,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}); NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32);
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32); NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32);
NDArray cropSize = NDArrayFactory::create<int>({3, 3}); NDArray cropSize = NDArrayFactory::create<int>({3, 3});
@ -2235,7 +2341,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
NDArray images('c', {1, 100, 100, 3}); NDArray images('c', {1, 100, 100, 3}, nd4j::DataType::FLOAT32);
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
NDArray boxI('c', {2}, {1,1}, nd4j::DataType::INT32); NDArray boxI('c', {2}, {1,1}, nd4j::DataType::INT32);
NDArray cropSize = NDArrayFactory::create<int>({10, 10}); NDArray cropSize = NDArrayFactory::create<int>({10, 10});
@ -2260,23 +2366,23 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
NDArray images = NDArrayFactory::create<float>('c', {2,4,5,3}); NDArray images = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, { NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {
0. , 0. , 1. , 1. , 0.1, 0.2, 0.9, 0.8, 0.f , 0.f , 1.f , 1.f , 0.1f, 0.2f, 0.9f, 0.8f,
0.3, 0.3, 0.7, 0.7, 0.4, 0.4, 0.6, 0.6 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f
}); });
NDArray colors = NDArrayFactory::create<float>('c', {2, 3}, {201., 202., 203., 127., 128., 129.}); NDArray colors = NDArrayFactory::create<float>('c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {2,4,5,3}, { NDArray expected = NDArrayFactory::create<float>('c', {2,4,5,3}, {
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203., 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
127., 128., 129., 19., 20., 21., 22., 23., 24., 127., 128., 129., 201., 202., 203., 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203., 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f,
201., 202., 203., 201. ,202. ,203., 201., 202., 203., 201., 202., 203., 201., 202., 203., 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f,
61., 62., 63., 201., 202., 203., 201., 202., 203., 70., 71., 72., 73., 74., 75., 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f,
76., 77., 78., 127., 128., 129., 127., 128., 129., 85., 86., 87., 88., 89., 90., 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f,
91., 92., 93., 201., 202., 203., 201., 202., 203., 100., 101., 102., 103., 104., 105., 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f,
106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120. 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f
}); });
images.linspace(1.); images.linspace(1.);
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
@ -2297,20 +2403,20 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
NDArray images = NDArrayFactory::create<float>('c', {1,9,9,1}); NDArray images = NDArrayFactory::create<float>('c', {1,9,9,1});
NDArray boxes = NDArrayFactory::create<float>('c', {1, 1, 4}, {0.2, 0.2, 0.7, 0.7}); NDArray boxes = NDArrayFactory::create<float>('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f});
NDArray colors = NDArrayFactory::create<float>('c', {1, 1}, {0.95}); NDArray colors = NDArrayFactory::create<float>('c', {1, 1}, {0.95f});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {1,9,9,1}, { NDArray expected = NDArrayFactory::create<float>('c', {1,9,9,1}, {
1.1 , 2.1, 3.1 , 4.1 , 5.1 , 6.1 , 7.1 , 8.1 , 9.1 , 1.1f , 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f , 8.1f , 9.1f ,
10.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 16.1 , 17.1 , 18.1 , 10.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f , 17.1f , 18.1f ,
19.1 , 0.95, 21.1, 22.1, 23.1, 0.95, 25.1 , 26.1 , 27.1 , 19.1f , 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f , 26.1f , 27.1f ,
28.1 , 0.95, 30.1, 31.1, 32.1, 0.95, 34.1 , 35.1 , 36.1 , 28.1f , 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f , 35.1f , 36.1f ,
37.1 , 0.95, 39.1, 40.1, 41.1, 0.95, 43.1 , 44.1 , 45.1 , 37.1f , 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f , 44.1f , 45.1f ,
46.1 , 0.95, 0.95, 0.95, 0.95, 0.95, 52.1 , 53.1 , 54.1 , 46.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f , 53.1f , 54.1f ,
55.1 , 56.1, 57.1 , 58.1 , 59.1 , 60.1 , 61.1 , 62.1 , 63.1 , 55.1f , 56.1f, 57.1f, 58.1f, 59.1f , 60.1f, 61.1f , 62.1f , 63.1f ,
64.1 , 65.1, 66.1 , 67.1 , 68.1 , 69.1 , 70.1 , 71.1 , 72.1 , 64.1f , 65.1f, 66.1f, 67.1f, 68.1f , 69.1f, 70.1f , 71.1f , 72.1f ,
73.1 , 74.1, 75.1 , 76.1 , 77.1 , 78.1 , 79.1 , 80.1 , 81.1 }); 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f });
images.linspace(1.1); images.linspace(1.1);
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {}); auto results = op.execute({&images, &boxes, &colors}, {}, {});
@ -2329,18 +2435,21 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
NDArray images = NDArrayFactory::create<float>('c', {2,5,5,1}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, NDArray images = NDArrayFactory::create<float>('c', {2,5,5,1}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f,
0.5056, 0.8925, 0.5461, 0.9234, 0.0856,0.7938, 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f,
0.6591, 0.5555, 0.1596, 0.3087, 0.1548, 0.4695, 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f,
0.9939, 0.6113,0.6765, 0.1800, 0.6750, 0.2246, 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f,
0.0509, 0.4601, 0.8284, 0.2354, 0.9752, 0.8361, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f,
0.2585, 0.4189, 0.7028, 0.7679, 0.5373, 0.7234, 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, 0.7234f,
0.2690, 0.0062, 0.0327, 0.0644, 0.8428, 0.7494, 0.2690f, 0.0062f, 0.0327f, 0.0644f, 0.8428f, 0.7494f,
0.0755, 0.6245, 0.3491, 0.5793, 0.5730, 0.1822, 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f,
0.6420, 0.9143}); 0.6420f, 0.9143f});
NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {0.7717, 0.9281, 0.9846, 0.4838, 0.6433, 0.6041, 0.6501, 0.7612, 0.7605, 0.3948, 0.9493, 0.8600, 0.7876, 0.8945, 0.4638, 0.7157}); NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f,
NDArray colors = NDArrayFactory::create<float>('c', {1, 2}, {0.9441, 0.5957}); 0.6433f, 0.6041f, 0.6501f, 0.7612f,
0.7605f, 0.3948f, 0.9493f, 0.8600f,
0.7876f, 0.8945f, 0.4638f, 0.7157f});
NDArray colors = NDArrayFactory::create<float>('c', {1, 2}, {0.9441f, 0.5957f});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
// NDArray expected = NDArrayFactory::create<float>('c', {2,5,5,1}, { // NDArray expected = NDArrayFactory::create<float>('c', {2,5,5,1}, {
@ -2351,17 +2460,17 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
// 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, // 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,
// 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f }); // 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f });
NDArray expected = NDArrayFactory::create<float>('c', {2,5,5,1}, { NDArray expected = NDArrayFactory::create<float>('c', {2,5,5,1}, {
0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856, 0.7938, 0.9441, 0.9441, 0.1596, 0.0856f, 0.7938f, 0.9441f, 0.9441f, 0.1596f,
0.3087, 0.1548, 0.4695, 0.9939, 0.6113, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f,
0.6765, 0.18 , 0.675 , 0.2246, 0.0509, 0.6765f, 0.18f , 0.675f , 0.2246f, 0.0509f,
0.4601, 0.8284, 0.2354, 0.9752, 0.8361, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f,
0.2585, 0.4189, 0.7028, 0.7679, 0.5373, 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f,
0.7234, 0.269 , 0.0062, 0.0327, 0.0644, 0.7234f, 0.269f , 0.0062f, 0.0327f, 0.0644f,
0.8428, 0.9441, 0.9441, 0.9441, 0.3491, 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f,
0.5793, 0.573 , 0.1822, 0.642 , 0.9143}); 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f});
nd4j::ops::draw_bounding_boxes op; nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {}); auto results = op.execute({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@ -477,6 +477,617 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) {
delete results; delete results;
} }
// Runs resize_bicubic on a rank-4 double input of shape {1, 7, 7, 1} with a
// requested size of {30, 30} and compares the first output against a
// hard-coded {1, 30, 30, 1} table of expected values.
// NOTE(review): the golden values presumably come from a reference
// implementation (TensorFlow v1 bicubic?) -- confirm before regenerating.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) {
// Source data in 'c' order. Note that the values are not a uniform ramp:
// several entries carry fractional offsets and the fifth row starts at 30
// rather than 29.
NDArray input = NDArrayFactory::create<double>('c', {1, 7, 7, 1}, {
1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
8, 9.1, 10., 11, 12.9, 13.1, 14,
15, 16., 17., 18, 19, 20., 21,
22, 23., 24., 25, 26, 27, 28,
30, 31, 32, 33, 34., 35, 36,
37, 38, 39, 40, 41., 42, 43,
44, 45, 46, 47, 48., 49, 50
});
// Expected output, 900 values in 'c' order: every five source lines below
// hold the 30 values of one output row.
NDArray expected = NDArrayFactory::create<double>('c', {1, 30, 30, 1}, {
1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});
// Requested output size, passed to the op as a second (int) input array
// rather than as integer arguments.
auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 30x30");
// expected.printBuffer("Expect for 30x30");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): if any ASSERT_* above fails, GoogleTest returns from the test
// body immediately and this delete is never reached -- confirm whether the
// leak on failure matters for this suite.
delete results;
}
// Runs resize_bicubic on a rank-4 double input of shape {2, 5, 4, 3} filled
// via linspace(1), with a requested size of {10, 8}, and compares the first
// output against a hard-coded {2, 10, 8, 3} table of expected values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) {
// Allocated without explicit values; populated by input.linspace(1) below.
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 4, 3});
// Expected output in 'c' order. The line breaks below do not follow the
// array's row structure.
NDArray expected = NDArrayFactory::create<double>('c', {2, 10, 8, 3}, {
1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5,
6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. ,
12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375,
8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625,
14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14.,
15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5,
19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125,
23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23.,
24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125,
28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875,
27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32.,
33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125,
31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5,
36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41.,
42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875,
40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125,
45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125,
46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625,
50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625,
54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53.,
54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125,
58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375,
52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125,
58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625,
61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 ,
66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. ,
72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375,
68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625,
74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625,
77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875,
76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79.,
80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83.,
84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81.,
80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5,
84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125,
88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125,
85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88.,
89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92.,
93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96.,
94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875,
93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5,
97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125,
100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97.,
98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101.,
102. ,101.5 ,102.5 ,103.5, 103. ,104., 105.,
104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125,
107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375,
107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625,
110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125,
114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110.,
111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114.,
113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125,
117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125,
120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375,
113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125,
117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125,
121.125 ,119.40625 ,120.40625, 121.40625}); //input = 1.f;
// Fill the input starting at 1.
// NOTE(review): presumably a unit step (1, 2, 3, ...) -- confirm against
// NDArray::linspace.
input.linspace(1);
// Requested output size, passed to the op as a second (int) input array.
auto size = NDArrayFactory::create<int>({10, 8});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 10x8");
// expected.printBuffer("Expect for 10x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): not reached when an ASSERT_* above fails, so `results` leaks
// on failure -- confirm whether that matters for this suite.
delete results;
}
// Runs resize_bicubic on a rank-4 double input of shape {1, 3, 3, 4} filled
// via linspace(1), with a requested size of {6, 6}, and compares the first
// output against a hard-coded {1, 6, 6, 4} table of expected values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) {
// Allocated without explicit values; populated by input.linspace(1) below.
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 3, 4});
// Expected output in 'c' order: each source line holds the 4 values along
// the last axis for one position, 36 lines in total (6 x 6 positions).
NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 6, 4}, {
1. ,2. ,3. ,4.,
2.625 ,3.625 ,4.625 ,5.625,
5. ,6. ,7. ,8.,
7.375 ,8.375 ,9.375, 10.375,
9. ,10. ,11. ,12.,
9.375 ,10.375 ,11.375 ,12.375,
5.875 ,6.875 ,7.875 , 8.875 ,
7.5 ,8.5 ,9.5 , 10.5 ,
9.875 ,10.875 ,11.875, 12.875,
12.25 ,13.25 ,14.25 , 15.25 ,
13.875 ,14.875 ,15.875, 16.875,
14.25 ,15.25 ,16.25 , 17.25 ,
13. ,14. ,15. ,16.,
14.625 ,15.625 ,16.625 ,17.625,
17. ,18. ,19. ,20.,
19.375 ,20.375 ,21.375 ,22.375,
21. ,22. ,23. ,24.,
21.375 ,22.375 ,23.375 ,24.375,
20.125 ,21.125 ,22.125 ,23.125,
21.75 ,22.75 ,23.75 ,24.75,
24.125 ,25.125 ,26.125 ,27.125,
26.5 ,27.5 ,28.5 ,29.5,
28.125 ,29.125 ,30.125 ,31.125,
28.5 ,29.5 ,30.5 ,31.5,
25. , 26. , 27. , 28.,
26.625 ,27.625 ,28.625 ,29.625,
29. ,30. ,31. ,32.,
31.375 ,32.375 ,33.375 ,34.375,
33. ,34. ,35. ,36.,
33.375 ,34.375 ,35.375 ,36.375,
26.125, 27.125, 28.125, 29.125,
27.75 ,28.75 ,29.75 ,30.75,
30.125 ,31.125 ,32.125 ,33.125,
32.5 ,33.5 ,34.5 ,35.5,
34.125 ,35.125 ,36.125 ,37.125,
34.5 ,35.5 ,36.5 ,37.5
});
// Fill the input starting at 1.
// NOTE(review): presumably a unit step (1, 2, 3, ...) -- confirm against
// NDArray::linspace.
input.linspace(1);
// Requested output size, passed to the op as a second (int) input array.
auto size = NDArrayFactory::create<int>({6, 6});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 6x6");
// expected.printBuffer("Expect for 6x6");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): not reached when an ASSERT_* above fails, so `results` leaks
// on failure -- confirm whether that matters for this suite.
delete results;
}
// Runs resize_bicubic on a rank-4 double input of shape {1, 3, 4, 3} filled
// via linspace(1), with a requested size of {6, 8}, and compares the first
// output against a hard-coded {1, 6, 8, 3} table of expected values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) {
// Allocated without explicit values; populated by input.linspace(1) below.
NDArray input = NDArrayFactory::create<double>('c', {1, 3, 4, 3});
// Expected output in 'c' order: each source line holds the 3 values along
// the last axis for one position, 48 lines in total (6 x 8 positions).
NDArray expected = NDArrayFactory::create<double>('c', {1, 6, 8, 3}, {
1. , 2. , 3. ,
2.21875 ,3.21875 ,4.21875,
4. ,5. ,6. ,
5.5 ,6.5 ,7.5 ,
7. ,8. ,9. ,
8.78125 ,9.78125, 10.78125,
10. ,11., 12. ,
10.28125 ,11.28125, 12.28125,
5.875 , 6.875 , 7.875 ,
7.09375 , 8.09375 , 9.09375,
8.875 , 9.875 ,10.875 ,
10.375 ,11.375 ,12.375 ,
11.875 ,12.875 ,13.875 ,
13.65625 ,14.65625 ,15.65625,
14.875 ,15.875 ,16.875 ,
15.15625 ,16.15625 ,17.15625,
13., 14., 15.,
14.21875 ,15.21875 ,16.21875,
16. ,17. ,18. ,
17.5 ,18.5 ,19.5 ,
19. ,20. ,21. ,
20.78125 ,21.78125 ,22.78125,
22. ,23. ,24. ,
22.28125 ,23.28125 ,24.28125,
20.125 , 21.125 , 22.125,
21.34375 ,22.34375 ,23.34375,
23.125 ,24.125 ,25.125 ,
24.625 ,25.625 ,26.625 ,
26.125 ,27.125 ,28.125 ,
27.90625 ,28.90625 ,29.90625,
29.125 ,30.125 ,31.125 ,
29.40625 ,30.40625 ,31.40625,
25. ,26. ,27. ,
26.21875 ,27.21875 ,28.21875,
28. ,29. ,30. ,
29.5 ,30.5 ,31.5 ,
31. ,32. ,33. ,
32.78125 ,33.78125 ,34.78125,
34. ,35. ,36. ,
34.28125 ,35.28125 ,36.28125,
26.125 ,27.125 , 28.125 ,
27.34375 ,28.34375 ,29.34375,
29.125 ,30.125 ,31.125 ,
30.625 ,31.625 ,32.625 ,
32.125 ,33.125 ,34.125 ,
33.90625 ,34.90625 ,35.90625,
35.125 ,36.125 ,37.125 ,
35.40625 ,36.40625 ,37.40625 });
// Fill the input starting at 1.
// NOTE(review): presumably a unit step (1, 2, 3, ...) -- confirm against
// NDArray::linspace.
input.linspace(1);
// Requested output size, passed to the op as a second (int) input array.
auto size = NDArrayFactory::create<int>({6, 8});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 6x8");
// expected.printBuffer("Expect for 6x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): not reached when an ASSERT_* above fails, so `results` leaks
// on failure -- confirm whether that matters for this suite.
delete results;
}
// Runs resize_bicubic on a rank-4 double input of shape {1, 4, 4, 3} filled
// via linspace(1), with a requested size of {8, 8}, and compares the first
// output against a hard-coded {1, 8, 8, 3} table of expected values.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) {
// Allocated without explicit values; populated by input.linspace(1) below.
NDArray input = NDArrayFactory::create<double>('c', {1, 4, 4, 3});
// Expected output in 'c' order, 8 values per source line: every three
// source lines hold the 24 values (8 positions x 3) of one output row.
// The initializer list ends with a trailing comma, which is valid C++.
NDArray expected = NDArrayFactory::create<double>('c', {1, 8, 8, 3}, {
1. ,2. , 3. , 2.21875 , 3.21875 , 4.21875 , 4. , 5. ,
6. ,5.5 , 6.5 , 7.5 , 7. , 8. , 9. , 8.78125 ,
9.78125 ,10.78125 ,10. ,11. ,12. ,10.28125 ,11.28125 ,12.28125 ,
5.875 ,6.875 , 7.875 , 7.09375 , 8.09375 , 9.09375 , 8.875 , 9.875 ,
10.875 ,10.375 , 11.375 , 12.375 , 11.875 , 12.875 , 13.875 , 13.65625,
14.65625 ,15.65625, 14.875 , 15.875 , 16.875 , 15.15625, 16.15625, 17.15625,
13. ,14. , 15. , 14.21875, 15.21875, 16.21875, 16. , 17. ,
18. ,17.5 , 18.5 , 19.5 , 19. , 20. , 21. , 20.78125,
21.78125 ,22.78125, 22. , 23. , 24. , 22.28125, 23.28125, 24.28125,
19. ,20. , 21. , 20.21875, 21.21875, 22.21875, 22. , 23. ,
24. ,23.5 , 24.5 , 25.5 , 25. , 26. , 27. , 26.78125,
27.78125 ,28.78125, 28. , 29. , 30. , 28.28125, 29.28125, 30.28125,
25. ,26. , 27. , 26.21875, 27.21875, 28.21875, 28. , 29. ,
30. ,29.5 , 30.5 , 31.5 , 31. , 32. , 33. , 32.78125,
33.78125 ,34.78125, 34. , 35. , 36. , 34.28125, 35.28125, 36.28125,
32.125 ,33.125 , 34.125 , 33.34375, 34.34375, 35.34375, 35.125 , 36.125 ,
37.125 ,36.625 , 37.625 , 38.625 , 38.125 , 39.125 , 40.125 , 39.90625,
40.90625 ,41.90625, 41.125 , 42.125 , 43.125 , 41.40625, 42.40625, 43.40625,
37. ,38. , 39. , 38.21875, 39.21875, 40.21875, 40. , 41. ,
42. ,41.5 , 42.5 , 43.5 , 43. , 44. , 45. , 44.78125,
45.78125 ,46.78125, 46. , 47. , 48. , 46.28125, 47.28125, 48.28125,
38.125 ,39.125 , 40.125 , 39.34375, 40.34375, 41.34375, 41.125 , 42.125 ,
43.125 ,42.625 , 43.625 , 44.625 , 44.125 , 45.125 , 46.125 , 45.90625,
46.90625 ,47.90625, 47.125 , 48.125 , 49.125 , 47.40625, 48.40625, 49.40625,
});
// Fill the input starting at 1.
// NOTE(review): presumably a unit step (1, 2, 3, ...) -- confirm against
// NDArray::linspace.
input.linspace(1);
// Requested output size, passed to the op as a second (int) input array.
auto size = NDArrayFactory::create<int>({8, 8});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 8x8");
// expected.printBuffer("Expect for 8x8");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): not reached when an ASSERT_* above fails, so `results` leaks
// on failure -- confirm whether that matters for this suite.
delete results;
}
// Runs resize_bicubic on a rank-3 double input of shape {7, 7, 1} (no leading
// dimension of size 1) with a requested size of {30, 30} and compares the
// first output against a hard-coded rank-3 {30, 30, 1} table of expected
// values.
// NOTE(review): the golden values presumably come from a reference
// implementation (TensorFlow v1 bicubic?) -- confirm before regenerating.
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
// Source data in 'c' order. Note that the values are not a uniform ramp:
// several entries carry fractional offsets and the fifth row starts at 30
// rather than 29.
NDArray input = NDArrayFactory::create<double>('c', {7, 7, 1}, {
1, 2.1, 3.15, 4.2, 5.15, 6.1, 7,
8, 9.1, 10., 11, 12.9, 13.1, 14,
15, 16., 17., 18, 19, 20., 21,
22, 23., 24., 25, 26, 27, 28,
30, 31, 32, 33, 34., 35, 36,
37, 38, 39, 40, 41., 42, 43,
44, 45, 46, 47, 48., 49, 50
});
// Expected output, 900 values in 'c' order: every five source lines below
// hold the 30 values of one output row.
NDArray expected = NDArrayFactory::create<double>('c', {30, 30, 1}, {
1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 ,
2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 ,
3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 ,
5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 ,
6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 ,
2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 ,
3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 ,
5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 ,
6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 ,
7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 ,
3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 ,
5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 ,
6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 ,
8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 ,
9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 ,
5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 ,
6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 ,
8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 ,
10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 ,
10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 ,
7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 ,
8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 ,
9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 ,
12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 ,
12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 ,
9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 ,
10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 ,
12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 ,
14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 ,
15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 ,
10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 ,
12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 ,
13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 ,
15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 ,
16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 ,
12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 ,
13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 ,
14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 ,
16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 ,
17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 ,
13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 ,
15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 ,
16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 ,
18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 ,
19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 ,
15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 ,
17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 ,
18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 ,
20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 ,
21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 ,
17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 ,
18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 ,
20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 ,
21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 ,
23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 ,
18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 ,
20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 ,
21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 ,
22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 ,
24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 ,
20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 ,
21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 ,
22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 ,
24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 ,
25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 ,
22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 ,
23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 ,
25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 ,
26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 ,
28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 ,
24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 ,
25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 ,
27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 ,
28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 ,
30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 ,
26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 ,
27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 ,
28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 ,
30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 ,
31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 ,
27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 ,
28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 ,
30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 ,
31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 ,
33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 ,
29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 ,
31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 ,
32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 ,
33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 ,
35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 ,
31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 ,
33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 ,
34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 ,
36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 ,
37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 ,
33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 ,
34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 ,
36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 ,
37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 ,
38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 ,
34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 ,
35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 ,
37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 ,
38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 ,
40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 ,
36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 ,
37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 ,
38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 ,
40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 ,
41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 ,
38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 ,
39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 ,
41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 ,
42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 ,
43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 ,
40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 ,
41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 ,
42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 ,
44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 ,
45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 ,
41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 ,
43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 ,
44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 ,
46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 ,
47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 ,
43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 ,
44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 ,
45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 ,
47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 ,
48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 ,
44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 ,
45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 ,
47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 ,
48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 ,
49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 ,
44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 ,
46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 ,
47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 ,
49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 ,
50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 ,
44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 ,
46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 ,
47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 ,
48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 ,
50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 ,
44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 ,
45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 ,
46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 ,
48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 ,
49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057});
// Requested output size, passed to the op as a second (int) input array
// rather than as integer arguments.
auto size = NDArrayFactory::create<int>({30, 30});
nd4j::ops::resize_bicubic op;
// No T-args or I-args are supplied (both argument lists are empty).
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 30x30");
// expected.printBuffer("Expect for 30x30");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// NOTE(review): if any ASSERT_* above fails, GoogleTest returns from the test
// body immediately and this delete is never reached -- confirm whether the
// leak on failure matters for this suite.
delete results;
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {