/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
|
|
|
|
//
|
|
// @author sgazeos@gmail.com
|
|
//
|
|
|
|
#include <system/op_boilerplate.h>
|
|
#if NOT_EXCLUDED(OP_resize_nearest_neighbor)
|
|
|
|
//#include <ops/declarable/headers/parity_ops.h>
|
|
#include <ops/declarable/CustomOperations.h>
|
|
#include <ops/declarable/helpers/image_resize.h>
|
|
|
|
namespace sd {
|
|
namespace ops {
|
|
/**
 * resize_nearest_neighbor: resizes a rank-3 or rank-4 input to a new
 * (height, width) by delegating to helpers::resizeNeighborFunctor.
 *
 * Inputs:
 *   0 - image array, rank 3 or rank 4
 *   1 - (optional) array of exactly two values: element 0 is the new height,
 *       element 1 is the new width
 * Integer args (used only when input 1 is absent; exactly two are required):
 *   0 - new height
 *   1 - new width
 * Boolean args (both optional):
 *   0 - alignCorners, defaults to false
 *   1 - halfPixelCenter, defaults to false; may not be set together with
 *       alignCorners
 *
 * Returns Status::OK() right away when the output array is empty; otherwise
 * returns whatever helpers::resizeNeighborFunctor returns. Any failed
 * REQUIRE_TRUE check aborts the op through that macro's error path.
 */
CUSTOM_OP_IMPL(resize_nearest_neighbor, 1, 1, false, 0, -2) {
  auto image = INPUT_VARIABLE(0);
  auto output = OUTPUT_VARIABLE(0);

  // Nothing to compute for an empty output.
  if (output->isEmpty()) return Status::OK();

  auto inRank = image->rankOf();

  int height = 0;
  int width = 0;
  if (block.width() > 1) {
    // Target size supplied as a second input array: {height, width}.
    auto newImageSize = INPUT_VARIABLE(1);
    REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
    REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
    height = newImageSize->e<int>(0);
    width = newImageSize->e<int>(1);
  } else {
    // Target size supplied as integer args: height first, then width.
    REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
    height = INT_ARG(0);
    width = INT_ARG(1);
  }

  bool alignCorners = false;     // default when boolean arg 0 is absent
  bool halfPixelCenter = false;  // default when boolean arg 1 is absent
  if (block.numB() > 0) alignCorners = B_ARG(0);
  if (block.numB() > 1) halfPixelCenter = B_ARG(1);

  // Each dimension is limited to 2^24, so both conditions must hold.
  REQUIRE_TRUE(width <= (1 << 24) && height <= (1 << 24), 0, "resize_nearest_neighbor: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
  // The format string consumes one integer, so the rank has to be passed explicitly.
  REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured", inRank);
  REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
  REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
  // halfPixelCenter and alignCorners are mutually exclusive.
  REQUIRE_TRUE(!(halfPixelCenter && alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false");
  // Both target dimensions must be positive; alignCorners adds no stricter lower bound.
  REQUIRE_TRUE(height > 0 && width > 0, 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);

  // Rank-4 arrays go to the helper directly, so no temporary array is built for them.
  if (inRank == 4) {
    return helpers::resizeNeighborFunctor(block.launchContext(), image, width, height, alignCorners, halfPixelCenter, output);
  }

  // Rank-3 arrays get a leading dimension of size 1 so the helper always sees rank 4.
  auto source = image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
  // NOTE(review): the trailing `false` presumably keeps `target` on output's buffer so the
  // helper's writes land in `output` - confirm against NDArray::reshape.
  auto target = output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);

  return helpers::resizeNeighborFunctor(block.launchContext(), &source, width, height, alignCorners, halfPixelCenter, &target);
}
|
|
|
|
/**
 * Shape function for resize_nearest_neighbor.
 *
 * Produces one output shape with the same rank, order and data type as the
 * first input. For rank 4 the dimensions at positions 1 and 2 are replaced by
 * the requested (height, width) while positions 0 and 3 are copied from the
 * input; for rank 3 positions 0 and 1 are replaced and position 2 is copied.
 *
 * The requested size comes from the second input array (two values: height,
 * width) when it is present, otherwise from exactly two integer args in the
 * same order.
 */
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
  auto shapeList = SHAPELIST();
  auto in = inputShape->at(0);
  auto inRank = shape::rank(in);

  REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: input image should be 4D tensor, but input has rank %i", inRank);

  // Same ordering as the op body: first value is height, second is width.
  int height = 0;
  int width = 0;
  if (block.width() > 1) {
    auto newImageSize = INPUT_VARIABLE(1);
    REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
    REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
    height = newImageSize->e<int>(0);
    width = newImageSize->e<int>(1);
  } else {
    // Both INT_ARG(0) and INT_ARG(1) are read below, so two int args must be present
    // (the op body enforces the same count).
    REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
    height = INT_ARG(0);
    width = INT_ARG(1);
  }

  Nd4jLong* outputShape;
  ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);

  // Slot 0 of a shape-info buffer holds the rank; dimensions start at slot 1.
  outputShape[0] = inRank;
  if (inRank == 4) {
    outputShape[1] = in[1];
    outputShape[2] = height;
    outputShape[3] = width;
    outputShape[4] = in[4];
  } else {  // input shape is 3D, so result also should be 3D
    outputShape[1] = height;
    outputShape[2] = width;
    outputShape[3] = in[3];
  }

  // Fill in strides and data type using the input's shape info and ordering.
  ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));

  shapeList->push_back(CONSTANT(outputShape));
  return shapeList;
}
|
|
/**
 * Type constraints for resize_nearest_neighbor: inputs and outputs are both
 * restricted to the ALL_INTS and ALL_FLOATS type groups.
 */
DECLARE_TYPES(resize_nearest_neighbor) {
  auto descriptor = getOpDescriptor();
  descriptor->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
  descriptor->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
}
|
|
|
|
}
|
|
}
|
|
|
|
#endif |