From 211c0df76f734028ebd1bf0a72825b6ba9fa9da6 Mon Sep 17 00:00:00 2001 From: Oleh Date: Fri, 20 Dec 2019 19:59:29 +0200 Subject: [PATCH] Oleh rgb to gray scale (#138) * libnd4j: RgbToGrayscale op #8536 - raw implementation in user branch, need checks for integration and adding other orders Signed-off-by: Oleg * libnd4j: RgbToGrayscale op #8536 next step of merging images Signed-off-by: Oleg * libnd4j: RgbToGrayscale op #8536, Revert merge of hsv_to_rgb and rgb_to_hsv as cause conflicts in naming need refactoring before merge, implementation of rbg_to_grs added * libnd4j: RgbToGrayscale op #8536 imlementation and conflict resolve * libnd4j: RgbToGrayscale op #8536 merged operations with images into image, renamed methods and files * libnd4j: RgbToGrayscale op #8536 added test for rgbToGrayScale, need clarification and fixed tests case run Signed-off-by: Oleg * libnd4j: RgbToGrayscale op #8536 bug fixing and need review * libnd4j: RgbToGrayscale op #8536 some additional corrections after review Signed-off-by: Oleg * - minor corrections in rgbToGrs test1 Signed-off-by: Yurii * libnd4j: RgbToGrayscale op #8536, corrected tests and rbf_to_grs, fixed problems, refactoring, need review * libnd4j: RgbToGrayscale op #8536 fix for 'f' order in rgbToGrs * libnd4j: RgbToGrayscale op #8536 fixed several bugs with dimC, test case refactoring and improve Signed-off-by: Oleg * - add cuda kernel for rgbToGrs op Signed-off-by: Yurii * - fix linkage errors Signed-off-by: Yurii Co-authored-by: Yurii Shyrma --- .../include/ops/declarable/CustomOperations.h | 2 +- .../generic/color_models/hsv_rgb_ops.cpp | 85 ------- .../declarable/generic/images/hsvToRgb.cpp | 59 +++++ .../declarable/generic/images/rgbToGrs.cpp | 74 ++++++ .../declarable/generic/images/rgbToHsv.cpp | 62 +++++ .../ops/declarable/headers/color_models.h | 56 ----- .../include/ops/declarable/headers/images.h | 71 ++++++ .../ops/declarable/helpers/cpu/hsv_rgb.cpp | 90 ------- .../declarable/helpers/cpu/imagesHelpers.cpp | 
137 +++++++++++ .../ops/declarable/helpers/cuda/hsv_rgb.cu | 139 ----------- .../declarable/helpers/cuda/imagesHelpers.cu | 228 ++++++++++++++++++ .../{color_models_conv.h => imagesHelpers.h} | 31 ++- .../layers_tests/DeclarableOpsTests15.cpp | 147 +++++++++++ .../layers_tests/DeclarableOpsTests16.cpp | 1 - 14 files changed, 801 insertions(+), 381 deletions(-) delete mode 100644 libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp create mode 100644 libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp create mode 100644 libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp create mode 100644 libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp delete mode 100644 libnd4j/include/ops/declarable/headers/color_models.h create mode 100644 libnd4j/include/ops/declarable/headers/images.h delete mode 100644 libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp delete mode 100644 libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu rename libnd4j/include/ops/declarable/helpers/{color_models_conv.h => imagesHelpers.h} (60%) diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index ebbc4b113..7d699c49b 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -41,7 +41,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp b/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp deleted file mode 100644 index 83d3c84ea..000000000 --- a/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019 Konduit 
K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - -#include -#include -#include -#include - -namespace nd4j { - namespace ops { - - - - CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); - - helpers::transform_hsv_rgb(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - - CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - if (input->isEmpty()) - return Status::OK(); - - const int rank = input->rankOf(); - const int arg_size = block.getIArguments()->size(); - const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? 
INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - - REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); - if (arg_size > 0) { - REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); - } - REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); - - helpers::transform_rgb_hsv(block.launchContext(), input, output, dimC); - - return Status::OK(); - } - - - DECLARE_TYPES(hsv_to_rgb) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - - DECLARE_TYPES(rgb_to_hsv) { - getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) - ->setSameMode(true); - } - } -} diff --git a/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp new file mode 100644 index 000000000..24c7c66e0 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/hsvToRgb.cpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Adel Rauf (rauf@konduit.ai) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transformHsvRgb(block.launchContext(), input, output, dimC); + + return Status::OK(); +} + +DECLARE_TYPES(hsv_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp new file mode 100644 index 000000000..aa2bec9da --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToGrs.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CUSTOM_OP_IMPL(rgb_to_grs, 1, 1, false, 0, 0) { + + const auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const int inRank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + + REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBGrayScale: operation expects 3 channels (R, G, B) in last dimention, but received %i instead", input->sizeAt(dimC)); + + helpers::transformRgbGrs(block.launchContext(), *input, *output, dimC); + return Status::OK(); +} + +DECLARE_TYPES(rgb_to_grs) { + getOpDescriptor()->setAllowedInputTypes( {ALL_INTS, ALL_FLOATS} ) + ->setSameMode(true); +} + +DECLARE_SHAPE_FN(rgb_to_grs) { + + const auto input = INPUT_VARIABLE(0); + const int inRank = input->rankOf(); + + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? 
INT_ARG(0) : INT_ARG(0) + inRank) : inRank - 1; + + REQUIRE_TRUE(inRank >= 1, 0, "RGBtoGrayScale: Fails to meet the inRank requirement: %i >= 1 ", inRank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < inRank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -inRank, inRank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoGrayScale: operation expects 3 channels (R, B, G) in last dimention, but received %i", dimC); + + auto nShape = input->getShapeAsVector(); + nShape[dimC] = 1; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(input->dataType(), input->ordering(), nShape)); +} + +} +} diff --git a/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp new file mode 100644 index 000000000..57faba562 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/images/rgbToHsv.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Adel Rauf (rauf@konduit.ai) +// + + + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { + +CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int argSize = block.getIArguments()->size(); + const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (argSize > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transformRgbHsv(block.launchContext(), input, output, dimC); + + return Status::OK(); +} + + +DECLARE_TYPES(rgb_to_hsv) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); +} + + +} +} diff --git a/libnd4j/include/ops/declarable/headers/color_models.h b/libnd4j/include/ops/declarable/headers/color_models.h deleted file mode 100644 index 5574749c2..000000000 --- a/libnd4j/include/ops/declarable/headers/color_models.h +++ /dev/null @@ -1,56 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - - -#ifndef LIBND4J_HEADERS_COLOR_MODELS_H -#define LIBND4J_HEADERS_COLOR_MODELS_H - -#include -#include -#include -#include -#include - -namespace nd4j { - namespace ops { - - /** - * Rgb To Hsv - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -#if NOT_EXCLUDED(OP_rgb_to_hsv) - DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0); -#endif - - /** - * Hsv To Rgb - * Input arrays: - * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. - * Int arguments: - * 0 - optional argument, corresponds to dimension with 3 channels - */ -#if NOT_EXCLUDED(OP_hsv_to_rgb) - DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0); -#endif - - } -} - -#endif diff --git a/libnd4j/include/ops/declarable/headers/images.h b/libnd4j/include/ops/declarable/headers/images.h new file mode 100644 index 000000000..d92fd62dc --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/images.h @@ -0,0 +1,71 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Oleh Semeniv (oleg.semeniv@gmail.com)
+//
+//
+// @author Adel Rauf (rauf@konduit.ai)
+//
+
+#ifndef LIBND4J_HEADERS_IMAGES_H
+#define LIBND4J_HEADERS_IMAGES_H
+
+#include
+#include
+#include
+#include
+#include
+
+namespace nd4j {
+namespace ops {
+
+
+/**
+ * Rgb To Hsv
+ * Input arrays:
+ * 0 - input array with rank >= 1; the channel dimension must have size 3.
+ * Int arguments:
+ * 0 - optional argument, index of the dimension holding the 3 channels (a negative value is offset by the rank);
+ *     the last dimension is used when it is omitted
+ */
+#if NOT_EXCLUDED(OP_rgb_to_hsv)
+    DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0);
+#endif
+
+/**
+ * Hsv To Rgb
+ * Input arrays:
+ * 0 - input array with rank >= 1; the channel dimension must have size 3.
+ * Int arguments:
+ * 0 - optional argument, index of the dimension holding the 3 channels (a negative value is offset by the rank);
+ *     the last dimension is used when it is omitted
+ */
+#if NOT_EXCLUDED(OP_hsv_to_rgb)
+    DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0);
+#endif
+
+/**
+* Rgb To GrayScale
+* Input arrays:
+* 0 - input array with rank >= 1, the RGB tensor to convert. The channel dimension must have size 3 and should contain RGB values.
+* Int arguments:
+* 0 - optional argument, index of the dimension holding the 3 channels (a negative value is offset by the rank);
+*     the last dimension is used when it is omitted
+* Output arrays:
+* 0 - array with the shape of the input except for the channel dimension, which has size 1
+*/
+#if NOT_EXCLUDED(OP_rgb_to_grs)
+    DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
+#endif
+
+}
+}
+
+#endif
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp
deleted file mode 100644
index 86ca2e991..000000000
--- a/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2019 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -#include -#include -#include -#include - -namespace nd4j { - namespace ops { - namespace helpers { - - //local - template - FORCEINLINE static void triple_transformer(const NDArray* input, NDArray* output, const int dimC, Op op) { - - const int rank = input->rankOf(); - - const T* x = input->bufferAsT(); - T* z = output->bufferAsT(); - - if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); - } - else { - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); - - const Nd4jLong numOfTads = packX.numberOfTads(); - const Nd4jLong xDimCstride = input->stridesOf()[dimC]; - const Nd4jLong zDimCstride = output->stridesOf()[dimC]; - - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i += increment) { - const T* xTad = x + packX.platformOffsets()[i]; - T* zTad = z + packZ.platformOffsets()[i]; - op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - - } - }; - - samediff::Threads::parallel_tad(func, 0, numOfTads); - } - } - - - - template - FORCEINLINE static void hsv_rgb(const NDArray* input, NDArray* output, const int dimC) { - auto op = nd4j::ops::helpers::hsvToRgb; - return triple_transformer(input, output, dimC, op); - } - - template - FORCEINLINE static void rgb_hsv(const NDArray* input, NDArray* output, const int dimC) { - auto op = nd4j::ops::helpers::rgbToHsv; - return triple_transformer(input, output, dimC, op); - } - - void 
transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), hsv_rgb, (input, output, dimC), FLOAT_TYPES); - } - - void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - BUILD_SINGLE_SELECTOR(input->dataType(), rgb_hsv, (input, output, dimC), FLOAT_TYPES); - } - - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp new file mode 100644 index 000000000..bc28dfa6a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/imagesHelpers.cpp @@ -0,0 +1,137 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// @author Adel Rauf (rauf@konduit.ai) +// + +#include +#include +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +template +static void rgbToGrs_(const NDArray& input, NDArray& output, const int dimC) { + + const T* x = input.bufferAsT(); + T* z = output.bufferAsT(); + const int rank = input.rankOf(); + + if(dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() && + 'c' == output.ordering() && 1 == output.ews()){ + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const auto xStep = i*3; + z[i] = 0.2989f*x[xStep] + 0.5870f*x[xStep + 1] + 0.1140f*x[xStep + 2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; + } + + auto func = PRAGMA_THREADS_FOR{ + + Nd4jLong coords[MAX_RANK]; + for (auto i = start; i < stop; i += increment) { + shape::index2coords(i, output.getShapeInfo(), coords); + const auto zOffset = shape::getOffset(output.getShapeInfo(), coords); + const auto xOffset0 = shape::getOffset(input.getShapeInfo(), coords); + const auto xOffset1 = xOffset0 + input.strideAt(dimC); + const auto xOffset2 = xOffset1 + input.strideAt(dimC); + z[zOffset] = 0.2989f*x[xOffset0] + 0.5870f*x[xOffset1] + 0.1140f*x[xOffset2]; + } + }; + + samediff::Threads::parallel_for(func, 0, output.lengthOf(), 1); + return; +} + +void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES); +} + + +template +FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) { + + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 
&& output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } + else { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } +} + + + +template +FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::hsvToRgb; + return tripleTransformer(input, output, dimC, op); +} + +template +FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::rgbToHsv; + return tripleTransformer(input, output, dimC, op); +} + +void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES); +} + +void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES); +} + +} +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu 
b/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu deleted file mode 100644 index 77faaa242..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu +++ /dev/null @@ -1,139 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -#include -#include -#include -#include -#include - -namespace nd4j { - namespace ops { - namespace helpers { - - template - static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - } - - template - static void 
_CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - const T* x = reinterpret_cast(vx); - T* z = reinterpret_cast(vz); - - __shared__ int rank; - __shared__ Nd4jLong xDimCstride, zDimCstride; - - if (threadIdx.x == 0) { - rank = shape::rank(xShapeInfo); - xDimCstride = shape::stride(xShapeInfo)[dimC]; - zDimCstride = shape::stride(zShapeInfo)[dimC]; - } - __syncthreads(); - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { - const T* xTad = x + xTadOffsets[i]; - T* zTad = z + zTadOffsets[i]; - - hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); - } - } - - /////////////////////////////////////////////////////////////////// - template - static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); - } - - template - static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - const Nd4jLong numOfTads, const int dimC) { - - rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); - } - - - void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), 
{dimC}); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "hsv_to_rgb"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); - } - - void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); - - const Nd4jLong numOfTads = packX.numberOfTads(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; - - PointersManager manager(context, "rgb_to_hsv"); - - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); - NDArray::registerSpecialUse({output}, {input}); - - manager.synchronize(); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu new file 
mode 100644 index 000000000..2731569a0 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/imagesHelpers.cu @@ -0,0 +1,228 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include +#include +#include + + +namespace nd4j { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +// for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4} +template +__global__ void rgbToGrsCuda(const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, *sharedMem; + __shared__ int rank; // xRank == zRank + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + } + __syncthreads(); + + Nd4jLong* coords = sharedMem + threadIdx.x * rank; + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { + + if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == 
shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) { + const auto xStep = i*3; + z[i] = 0.2989f * x[xStep] + 0.5870f * x[xStep + 1] + 0.1140f * x[xStep + 2]; + } + else { + + shape::index2coords(i, zShapeInfo, coords); + + const auto zOffset = shape::getOffset(zShapeInfo, coords); + const auto xOffset0 = shape::getOffset(xShapeInfo, coords); + const auto xOffset1 = xOffset0 + shape::stride(xShapeInfo)[dimC]; + const auto xOffset2 = xOffset1 + shape::stride(xShapeInfo)[dimC]; + + z[zOffset] = 0.2989f * x[xOffset0] + 0.5870f * x[xOffset1] + 0.1140f * x[xOffset2]; + } + } +} + +/////////////////////////////////////////////////////////////////// +template +linkage void rgbToGrsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int dimC) { + + rgbToGrsCuda<<>>(vx, xShapeInfo, vz, zShapeInfo, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) { + + PointersManager manager(context, "rgbToGrs"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + NDArray::prepareSpecialUse({&output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrsCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), dimC), NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + + manager.synchronize(); +} + + +/////////////////////////////////////////////////////////////////// +template +static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, 
const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, 
const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +template +static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); +} + +/////////////////////////////////////////////////////////////////// +void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "hsv_to_rgb"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = 
nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "rgb_to_hsv"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); +} + + + + + + + + + + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/color_models_conv.h b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h similarity index 60% rename from libnd4j/include/ops/declarable/helpers/color_models_conv.h rename to libnd4j/include/ops/declarable/helpers/imagesHelpers.h index c23e40dba..224c0b19f 100644 --- a/libnd4j/include/ops/declarable/helpers/color_models_conv.h +++ b/libnd4j/include/ops/declarable/helpers/imagesHelpers.h @@ -14,17 +14,30 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +// +// @author Oleh Semeniv (oleg.semeniv@gmail.com) +// +// +// @author Adel Rauf (rauf@konduit.ai) +// + +#ifndef LIBND4J_HELPERS_IMAGES_H +#define LIBND4J_HELPERS_IMAGES_H + #include #include #include namespace nd4j { - namespace ops { - namespace helpers { - - void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); - void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); - - } - } -} \ 
No newline at end of file +namespace ops { +namespace helpers { + + void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); + void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + +} +} +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 5697a5257..1217288a8 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -919,3 +919,150 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { ASSERT_EQ(true, z.e(0)); } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { + // rank 1 + NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); + NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({&rgbs}, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { + // rank 1 + auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); + auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { + // rank 2 + NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); + NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { + + NDArray rgbs('c', { 3, 2 }, {14, 99, 207, 10, 114, 201 }, nd4j::DataType::INT32); + + rgbs.permutei({1,0}); + NDArray expected('c', { 2, 1 }, { 138, 58 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { + // rank 2 + NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, nd4j::DataType::INT32); + NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, nd4j::DataType::INT32); + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {0}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,4,3 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 
3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 
6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545, 38.607746, -40.614971, 18.233341, -51.545094,2.234142, 20.913160, 8.783220, 15.955761, 55.273506, 36.838833, -29.751089, 8.148357, 13.676106, 1.097548, 68.766457, 38.690712, 27.176361, -14.156269, 7.157052 }); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {1}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + try { + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + ASSERT_EQ(Status::THROW(), result->status()); + delete result; + } catch (std::exception& e) { + 
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { + // rank 3 + auto rgbs = NDArrayFactory::create('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f}); + auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); + + nd4j::ops::rgb_to_grs op; + auto result = op.execute({ &rgbs }, {}, {}); + auto output = result->at(0); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete result; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index a2e33ec4f..303580205 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -239,7 +239,6 @@ TEST_F(DeclarableOpsTests16, test_reverse_1) { } } - TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { /* test case generated by python colorsys and scaled to suit our needs