diff --git a/libnd4j/CMakeSettings.json b/libnd4j/CMakeSettings.json index 2bb5bddbc..afda69260 100644 --- a/libnd4j/CMakeSettings.json +++ b/libnd4j/CMakeSettings.json @@ -12,6 +12,21 @@ "cmakeCommandArgs": " -DCUDA_BLAS=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true", "buildCommandArgs": "-v", "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Debug", + "generator": "Unix Makefiles", + "configurationType": "Debug", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "/usr/bin/cmake", + "cmakeCommandArgs": "-DLIBND4J_ALL_OPS=true -DCMAKE_BUILD_TYPE=Debug -DCPU_BLAS=true -DLIBND4J_NAME=nd4jcpu -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug -DOPENBLAS_PATH=/usr/lib/openblas-base/ -DEXTENSION=avx2 ", + "buildCommandArgs": "-j 4", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}", + "addressSanitizerRuntimeFlags": "detect_leaks=0", + "variables": [] } ] } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/CustomOperations.h b/libnd4j/include/ops/declarable/CustomOperations.h index 5aea215c1..ebbc4b113 100644 --- a/libnd4j/include/ops/declarable/CustomOperations.h +++ b/libnd4j/include/ops/declarable/CustomOperations.h @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include diff --git a/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp b/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp new file mode 100644 index 000000000..83d3c84ea --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/color_models/hsv_rgb_ops.cpp @@ -0,0 +1,85 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + + + + CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transform_hsv_rgb(block.launchContext(), input, output, dimC); + + return Status::OK(); + } + + CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (input->isEmpty()) + return Status::OK(); + + const int rank = input->rankOf(); + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + + REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } + REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); + + helpers::transform_rgb_hsv(block.launchContext(), input, output, dimC); + + return Status::OK(); + } + + + DECLARE_TYPES(hsv_to_rgb) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + + DECLARE_TYPES(rgb_to_hsv) { + getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS }) + ->setSameMode(true); + } + } +} diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp index 32e51bdb9..003ff6e75 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp @@ -39,10 +39,14 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); NDArray* delta = nullptr; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp index de947c9ae..0a8eaf0c3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp @@ -38,9 +38,13 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { return Status::OK(); const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; + const int arg_size = block.getIArguments()->size(); + const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); + if (arg_size > 0) { + REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); + } REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); diff --git a/libnd4j/include/ops/declarable/headers/color_models.h b/libnd4j/include/ops/declarable/headers/color_models.h new file mode 100644 index 000000000..5574749c2 --- /dev/null +++ b/libnd4j/include/ops/declarable/headers/color_models.h @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + + +#ifndef LIBND4J_HEADERS_COLOR_MODELS_H +#define LIBND4J_HEADERS_COLOR_MODELS_H + +#include +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + + /** + * Rgb To Hsv + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_rgb_to_hsv) + DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0); +#endif + + /** + * Hsv To Rgb + * Input arrays: + * 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels. + * Int arguments: + * 0 - optional argument, corresponds to dimension with 3 channels + */ +#if NOT_EXCLUDED(OP_hsv_to_rgb) + DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0); +#endif + + } +} + +#endif diff --git a/libnd4j/include/ops/declarable/helpers/color_models_conv.h b/libnd4j/include/ops/declarable/helpers/color_models_conv.h new file mode 100644 index 000000000..c23e40dba --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/color_models_conv.h @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + + void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp b/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp new file mode 100644 index 000000000..86ca2e991 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/hsv_rgb.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + + //local + template + FORCEINLINE static void triple_transformer(const NDArray* input, NDArray* output, const int dimC, Op op) { + + const int rank = input->rankOf(); + + const T* x = input->bufferAsT(); + T* z = output->bufferAsT(); + + if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3); + } + else { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC); + + const Nd4jLong numOfTads = packX.numberOfTads(); + const Nd4jLong xDimCstride = input->stridesOf()[dimC]; + const Nd4jLong zDimCstride = output->stridesOf()[dimC]; + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i += increment) { + const T* xTad = x + packX.platformOffsets()[i]; + T* zTad = z + packZ.platformOffsets()[i]; + op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + + } + }; + + samediff::Threads::parallel_tad(func, 0, numOfTads); + } + } + + + + template + FORCEINLINE static void hsv_rgb(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::hsvToRgb; + return triple_transformer(input, output, dimC, op); + } + + template + FORCEINLINE static void rgb_hsv(const NDArray* input, NDArray* output, const int dimC) { + auto op = nd4j::ops::helpers::rgbToHsv; + return triple_transformer(input, output, dimC, op); + } + + void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), hsv_rgb, (input, output, dimC), FLOAT_TYPES); + } + + void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + BUILD_SINGLE_SELECTOR(input->dataType(), rgb_hsv, (input, output, dimC), FLOAT_TYPES); + } + + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu b/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu new file mode 100644 index 000000000..77faaa242 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/hsv_rgb.cu @@ -0,0 +1,139 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +#include +#include +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + + template + static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + rgbToHsv(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + } + + template + static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + const T* x = reinterpret_cast(vx); + T* z = reinterpret_cast(vz); + + __shared__ int rank; + __shared__ Nd4jLong xDimCstride, zDimCstride; + + if (threadIdx.x == 0) { + rank = shape::rank(xShapeInfo); + xDimCstride = shape::stride(xShapeInfo)[dimC]; + zDimCstride = shape::stride(zShapeInfo)[dimC]; + } + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { + const T* xTad = x + xTadOffsets[i]; + T* zTad = z + zTadOffsets[i]; + + hsvToRgb(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); + } + } + + /////////////////////////////////////////////////////////////////// + template + static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + hsvToRgbCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); + } + + template + static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, + void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, + const Nd4jLong numOfTads, const int dimC) { + + rgbToHsvCuda<<>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC); + } + + + void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "hsv_to_rgb"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); + } + + void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC}); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); + + const Nd4jLong numOfTads = packX.numberOfTads(); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "rgb_to_hsv"); + + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES); + NDArray::registerSpecialUse({output}, {input}); + + manager.synchronize(); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index f8bf47e53..e46db7a90 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -1,240 +1,695 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - -// -// @author raver119@gmail.com -// - -#include "testlayers.h" -#include -#include -#include -#include -#include - - -using namespace nd4j; - - -class DeclarableOpsTests16 : public testing::Test { -public: - - DeclarableOpsTests16() { - printf("\n"); - fflush(stdout); - } -}; - -TEST_F(DeclarableOpsTests16, scatter_upd_1) { - auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); - auto y = NDArrayFactory::create(0); - auto w = NDArrayFactory::create(3.0f); - auto e = NDArrayFactory::create('c', {3}, {3.f, 1.f, 1.f}); - - nd4j::ops::scatter_upd op; - auto result = op.execute({&x, &y, &w}, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, scatter_upd_2) { - - NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); - NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32); - NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); - NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32); - - x.linspace(1); - - nd4j::ops::scatter_upd op; - auto result = op.execute({&x, &indices, &updates}, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, scatter_upd_3) { - - NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); - NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32); - NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); - NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32); - - nd4j::ops::scatter_upd op; - ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); -} - -TEST_F(DeclarableOpsTests16, test_size_dtype_1) { - auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); - auto z = NDArrayFactory::create(0.0f); - auto e = NDArrayFactory::create(3.0f); - - nd4j::ops::size op; - auto status = op.execute({&x}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(e, z); -} - -TEST_F(DeclarableOpsTests16, test_empty_noop_1) { - auto z = NDArrayFactory::empty(); - - nd4j::ops::noop op; - auto status = op.execute({}, {&z}, {}, {}, {}); - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_empty_noop_2) { - auto z = NDArrayFactory::empty(); - - Context ctx(1); - ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); - - nd4j::ops::noop op; - auto status = op.execute(&ctx); - - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_svd_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f}); - auto z = NDArrayFactory::create('c', {3}); - - nd4j::ops::svd op; - auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); - - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { - auto x = NDArrayFactory::create({37, 37, 37}); - auto y = NDArrayFactory::create({8723, 8723, 8723}); - auto e = NDArrayFactory::create(18); - - nd4j::ops::bits_hamming_distance op; - auto result = op.execute({&x, &y}, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - ASSERT_EQ(e, *z); - - delete result; -} - -TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { - auto input = NDArrayFactory::create('c', {512}); - auto low = NDArrayFactory::create('c', {512}); - auto high = NDArrayFactory::create('c', {512}); - - auto output = NDArrayFactory::create(0.0f); - - input.linspace(1.0); - low.linspace(1.0); - high.linspace(1.0); - - nd4j::ops::knn_mindistance op; - auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); - ASSERT_EQ(Status::OK(), result); -} - -TEST_F(DeclarableOpsTests16, test_empty_cast_1) { - auto x = NDArrayFactory::create('c', {1, 0, 2}); - auto e = NDArrayFactory::create('c', {1, 0, 2}); - - nd4j::ops::cast op; - auto result = op.execute({&x}, {}, {10}); - ASSERT_EQ(Status::OK(), result->status()); - ASSERT_EQ(e, *result->at(0)); - - delete result; -} - -TEST_F(DeclarableOpsTests16, test_range_1) { - nd4j::ops::range op; - auto z = NDArrayFactory::create('c', {200}); - - Context ctx(1); - ctx.setTArguments({-1.0, 1.0, 0.01}); - ctx.setOutputArray(0, &z); - - auto status = op.execute(&ctx); - ASSERT_EQ(Status::OK(), status); -} - -TEST_F(DeclarableOpsTests16, test_range_2) { - nd4j::ops::range op; - auto z = NDArrayFactory::create('c', {200}); - - double tArgs[] = {-1.0, 1.0, 0.01}; - - auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); - shape::printShapeInfoLinear("Result", shapes->at(0)); - ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); - - delete shapes; -} - -TEST_F(DeclarableOpsTests16, test_reverse_1) { - std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; - std::vector columns = {6, 5, 10, 100, 153, 171, 635}; - - for (auto r : rows) { - for (auto c : columns) { - //nd4j_printf("Trying [%i, %i]\n", r, c); - auto array = NDArrayFactory::create('c', {r, c}); - auto exp = NDArrayFactory::create('c', {r, c}); - auto reversed = NDArrayFactory::create('c', {r, c}); - - auto rowOriginal = NDArrayFactory::create('c', {c}); - auto rowReversed = NDArrayFactory::create('c', {c}); - - for (int e = 0; e < c; e++) { - rowOriginal.p(e, (float) e); - rowReversed.p(c - e - 1, (float) e); - } - - - auto listI = array.allTensorsAlongDimension({1}); - auto listE = exp.allTensorsAlongDimension({1}); - - for (int e = 0; e < r; e++) { - listI->at(e)->assign(rowOriginal); - listE->at(e)->assign(rowReversed); - } - - delete listI; - delete listE; - - nd4j::ops::reverse op; - Nd4jLong axis = 1; - auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); - ASSERT_EQ(Status::OK(), status); - - ASSERT_EQ(exp, reversed); - } - } +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests16 : public testing::Test { +public: + + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests16, scatter_upd_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); + auto y = NDArrayFactory::create(0); + auto w = NDArrayFactory::create(3.0f); + auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); + + nd4j::ops::scatter_upd op; + auto result = op.execute({ &x, &y, &w }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, scatter_upd_2) { + + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 2,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, nd4j::DataType::FLOAT32); + + x.linspace(1); + + nd4j::ops::scatter_upd op; + auto result = op.execute({ &x, &indices, &updates }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, scatter_upd_3) { + + NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 20,5 }, nd4j::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32); + NDArray output('c', { 10, 3 }, nd4j::DataType::FLOAT32); + + nd4j::ops::scatter_upd op; + ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); +} + +TEST_F(DeclarableOpsTests16, test_size_dtype_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1, 1, 1 }); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + nd4j::ops::size op; + auto status = op.execute({ &x }, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_1) { + auto z = NDArrayFactory::empty(); + + nd4j::ops::noop op; + auto status = op.execute({}, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_2) { + auto z = NDArrayFactory::empty(); + + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::noop op; + auto status = op.execute(&ctx); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_svd_1) { + auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); + auto z = NDArrayFactory::create('c', { 3 }); + + nd4j::ops::svd op; + auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { + auto x = NDArrayFactory::create({ 37, 37, 37 }); + auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); + auto e = NDArrayFactory::create(18); + + nd4j::ops::bits_hamming_distance op; + auto result = op.execute({ &x, &y }, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { + auto input = NDArrayFactory::create('c', { 512 }); + auto low = NDArrayFactory::create('c', { 512 }); + auto high = NDArrayFactory::create('c', { 512 }); + + auto output = NDArrayFactory::create(0.0f); + + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); + + nd4j::ops::knn_mindistance op; + auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests16, test_empty_cast_1) { + auto x = NDArrayFactory::create('c', { 1, 0, 2 }); + auto e = NDArrayFactory::create('c', { 1, 0, 2 }); + + nd4j::ops::cast op; + auto result = op.execute({ &x }, {}, { 10 }); + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_EQ(e, *result->at(0)); + + delete result; +} + +TEST_F(DeclarableOpsTests16, test_range_1) { + nd4j::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + Context ctx(1); + ctx.setTArguments({ -1.0, 1.0, 0.01 }); + ctx.setOutputArray(0, &z); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_range_2) { + nd4j::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + double tArgs[] = { -1.0, 1.0, 0.01 }; + + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); + shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + + delete shapes; +} + +TEST_F(DeclarableOpsTests16, test_reverse_1) { + std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; + std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; + + for (auto r : rows) { + for (auto c : columns) { + //nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', { r, c }); + auto exp = NDArrayFactory::create('c', { r, c }); + auto reversed = NDArrayFactory::create('c', { r, c }); + + auto rowOriginal = NDArrayFactory::create('c', { c }); + auto rowReversed = NDArrayFactory::create('c', { c }); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); + } + + + auto listI = array.allTensorsAlongDimension({ 1 }); + auto listE = exp.allTensorsAlongDimension({ 1 }); + + for (int e = 0; e < r; e++) { + listI->at(e)->assign(rowOriginal); + listE->at(e)->assign(rowReversed); + } + + delete listI; + delete listE; + + nd4j::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); + } + } +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.array([randint(0,255) for x in range(0,3*4*5)]).reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0]/255,x[1]/255,x[2]/255))*np.array([360,1,1]),2,rgbs) + rgbs.ravel() + hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, + { + 213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, + 50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, + 59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, + 156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, + 115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, + { + 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, + 3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, + 9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, + 2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, + 4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, + 9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, + 3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, + 2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, + 6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, + 2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, + 7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, + 9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, + 1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, + 9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, + 9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, + { + 213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f, + 191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f, + 168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f, + 221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f, + 209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f + }); + auto expected = NDArrayFactory::create('c', { 5,3,4 }, + { + 6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f, + 2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f, + 8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f, + 2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f, + 7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f, + 9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f, + 3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f, + 5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f, + 4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f, + 2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f, + 7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f, + 9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f, + 1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f, + 9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f, + 8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f + }); + + + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { + /* + 2D + */ + auto rgbs = NDArrayFactory::create('c', { 8,3 }, + { 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, + 153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f }); + auto expected = NDArrayFactory::create('c', { 8,3 }, + { 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, + 0.9047619f, 0.65882353f, 71.30044843f, 1.f, + 0.8745098f, 180.f, 0.74871795f, 0.76470588f, + 77.6f, 0.49019608f, 0.6f, 260.74468085f, + 0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f, + 0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f }); + + + auto actual = NDArrayFactory::create('c', { 8,3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { + /* + 2D + */ + auto rgbs = NDArrayFactory::create('c', { 3,8 }, + { 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, + 21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); + auto expected = NDArrayFactory::create('c', { 3, 8 }, + { 263.25842697f, 279.86842105f, 71.30044843f, 180.f, + 77.6f, 260.74468085f, 296.12903226f, 289.82142857f, + 0.74476987f, 0.9047619f, 1.f, 0.74871795f, + 0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f, + 0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f, + 0.6f, 0.81960784f, 0.41960784f, 0.82352941f }); + + + auto actual = NDArrayFactory::create('c', { 3, 8 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { + /* + + */ + auto rgbs = NDArrayFactory::create('c', { 3 }, + { 213.f, 220.f, 164.f }); + auto expected = NDArrayFactory::create('c', { 3 }, + { 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f }); + + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { + /* + + */ + auto rgbs = NDArrayFactory::create('c', { 3,8 }, + { 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, + 21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); + + auto expected = NDArrayFactory::create('c', { 3 }, + { 263.25842697f, 0.74476987f, 0.9372549f }); + + //get subarray + std::unique_ptr subArrRgbs(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) })); + subArrRgbs->reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs->printShapeInfo("subArrRgbs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, subArrRgbs.get()); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + subArrRgbs->printBuffer("subArrRgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { + /* + using the same numbers of rgb_to_hsv_1 test + */ + auto expected = NDArrayFactory::create('c', { 5,4,3 }, + { 213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, + 50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, + 59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, + 156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, + 115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f } + ); + auto hsvs = NDArrayFactory::create('c', { 5,4,3 }, + { + 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, + 3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, + 9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, + 2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, + 4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, + 9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, + 3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, + 2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, + 6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, + 2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, + 7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, + 9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, + 1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, + 9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, + 9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { + /* + using the same numbers of hsv_to_rgb_2 + */ + auto expected = NDArrayFactory::create('c', { 5,3,4 }, + { 213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f, + 191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f, + 168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f, + 221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f, + 209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f } + ); + auto hsvs = NDArrayFactory::create('c', { 5,3,4 }, + { + 6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f, + 2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f, + 8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f, + 2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f, + 7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f, + 9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f, + 3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f, + 5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f, + 4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f, + 2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f, + 7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f, + 9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f, + 1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f, + 9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f, + 8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f + }); + + + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { + /* + 2D + */ + auto expected = NDArrayFactory::create('c', { 8,3 }, + { 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, + 153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f }); + auto hsvs = NDArrayFactory::create('c', { 8,3 }, + { 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, + 0.9047619f, 0.65882353f, 71.30044843f, 1.f, + 0.8745098f, 180.f, 0.74871795f, 0.76470588f, + 77.6f, 0.49019608f, 0.6f, 260.74468085f, + 0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f, + 0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f }); + + + auto actual = NDArrayFactory::create('c', { 8,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { + /* + 2D + */ + auto expected = NDArrayFactory::create('c', { 3,8 }, + { 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, + 21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); + auto hsvs = NDArrayFactory::create('c', { 3,8 }, + { 263.25842697f, 279.86842105f, 71.30044843f, 180.f, + 77.6f, 260.74468085f, 296.12903226f, 289.82142857f, + 0.74476987f, 0.9047619f, 1.f, 0.74871795f, + 0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f, + 0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f, + 0.6f, 0.81960784f, 0.41960784f, 0.82352941f }); + + + auto actual = NDArrayFactory::create('c', { 3,8 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + /* + + */ + auto expected = NDArrayFactory::create('c', { 3 }, + { 213.f, 220.f, 164.f }); + auto hsvs = NDArrayFactory::create('c', { 3 }, + { 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f }); + + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { + + auto expected = NDArrayFactory::create('c', { 3 }, + { 130.0, 61.0, 239.0 }); + auto hsvs = NDArrayFactory::create('c', { 3,8 }, + { 263.25842697, 279.86842105, 71.30044843, 180, + 77.6, 260.74468085, 296.12903226, 289.82142857, + 0.74476987, 0.9047619, 1., 0.74871795, + 0.49019608, 0.89952153, 0.86915888, 0.53333333, + 0.9372549, 0.65882353, 0.8745098, 0.76470588, + 0.6, 0.81960784, 0.41960784, 0.82352941 + }); + + //get subarray + std::unique_ptr subArrHsvs(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) })); + subArrHsvs->reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrHsvs->printShapeInfo("subArrHsvs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, subArrHsvs.get()); + ctx.setOutputArray(0, &actual); + nd4j::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); +#if 0 + //visual check + subArrHsvs->printBuffer("subArrHsvs "); + actual.printBuffer("rgb "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + } \ No newline at end of file