[WIP] HSV,RGB color model conversions (#125)

* CUDA implementation for hsv_to_rgb and rgb_to_hsv

Signed-off-by: raver119 <raver119@gmail.com>

* hsv_to_rgb and rgb_to_hsv operations
Test coverage: c order 1d, 2d, 3d array

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Index check

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Suppress Msvc floating point errors

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Added Index Check for adjust_saturation and adjust_hue

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* Fixes missed Msvc floating narrowing errors

Signed-off-by: Abdelrauf <rauf@konduit.ai>
master
Abdelrauf 2019-12-17 10:42:09 +04:00 committed by raver119
parent bfd9e3692a
commit e0a9cb6c08
10 changed files with 1120 additions and 241 deletions

View File

@ -12,6 +12,21 @@
"cmakeCommandArgs": " -DCUDA_BLAS=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true", "cmakeCommandArgs": " -DCUDA_BLAS=true -DLIBND4J_NAME=nd4jcuda -DMSVC_DEV=true -DCOMPUTE=61 -DBUILD_TESTS=true",
"buildCommandArgs": "-v", "buildCommandArgs": "-v",
"ctestCommandArgs": "" "ctestCommandArgs": ""
},
{
"name": "WSL-GCC-Debug",
"generator": "Unix Makefiles",
"configurationType": "Debug",
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeExecutable": "/usr/bin/cmake",
"cmakeCommandArgs": "-DLIBND4J_ALL_OPS=true -DCMAKE_BUILD_TYPE=Debug -DCPU_BLAS=true -DLIBND4J_NAME=nd4jcpu -DBUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug -DOPENBLAS_PATH=/usr/lib/openblas-base/ -DEXTENSION=avx2 ",
"buildCommandArgs": "-j 4",
"ctestCommandArgs": "",
"inheritEnvironments": [ "linux_x64" ],
"wslPath": "${defaultWSLPath}",
"addressSanitizerRuntimeFlags": "detect_leaks=0",
"variables": []
} }
] ]
} }

View File

@ -41,6 +41,7 @@
#include <ops/declarable/headers/tests.h> #include <ops/declarable/headers/tests.h>
#include <ops/declarable/headers/kernels.h> #include <ops/declarable/headers/kernels.h>
#include <ops/declarable/headers/BarnesHutTsne.h> #include <ops/declarable/headers/BarnesHutTsne.h>
#include <ops/declarable/headers/color_models.h>
#include <dll.h> #include <dll.h>
#include <helpers/shape.h> #include <helpers/shape.h>
#include <helpers/TAD.h> #include <helpers/TAD.h>

View File

@ -0,0 +1,85 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#include <ops/declarable/headers/color_models.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (input->isEmpty())
return Status::OK();
const int rank = input->rankOf();
const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(rank >= 1, 0, "HSVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "HSVtoRGB: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC));
helpers::transform_hsv_rgb(block.launchContext(), input, output, dimC);
return Status::OK();
}
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (input->isEmpty())
return Status::OK();
const int rank = input->rankOf();
const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(rank >= 1, 0, "RGBtoHSV: Fails to meet the rank requirement: %i >= 1 ", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC));
helpers::transform_rgb_hsv(block.launchContext(), input, output, dimC);
return Status::OK();
}
DECLARE_TYPES(hsv_to_rgb) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
DECLARE_TYPES(rgb_to_hsv) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

View File

@ -39,10 +39,14 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) {
return Status::OK(); return Status::OK();
const int rank = input->rankOf(); const int rank = input->rankOf();
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !");
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
NDArray* delta = nullptr; NDArray* delta = nullptr;

View File

@ -38,9 +38,13 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) {
return Status::OK(); return Status::OK();
const int rank = input->rankOf(); const int rank = input->rankOf();
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !");

View File

@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#ifndef LIBND4J_HEADERS_COLOR_MODELS_H
#define LIBND4J_HEADERS_COLOR_MODELS_H
#include <ops/declarable/headers/common.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
#include <ops/declarable/helpers/color_models_conv.h>
namespace nd4j {
namespace ops {
/**
* Rgb To Hsv
* Input arrays:
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_rgb_to_hsv)
DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0);
#endif
/**
* Hsv To Rgb
* Input arrays:
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_hsv_to_rgb)
DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0);
#endif
}
}
#endif

View File

@ -0,0 +1,30 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#include <op_boilerplate.h>
#include <templatemath.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
}
}
}

View File

@ -0,0 +1,90 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#include <ops/declarable/helpers/adjust_hue.h>
#include <ops/declarable/helpers/color_models_conv.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
namespace helpers {
//local
template <typename T, typename Op>
FORCEINLINE static void triple_transformer(const NDArray* input, NDArray* output, const int dimC, Op op) {
const int rank = input->rankOf();
const T* x = input->bufferAsT<T>();
T* z = output->bufferAsT<T>();
if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]);
}
};
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
}
else {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
const Nd4jLong numOfTads = packX.numberOfTads();
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
const T* xTad = x + packX.platformOffsets()[i];
T* zTad = z + packZ.platformOffsets()[i];
op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
};
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}
template <typename T>
FORCEINLINE static void hsv_rgb(const NDArray* input, NDArray* output, const int dimC) {
auto op = nd4j::ops::helpers::hsvToRgb<T>;
return triple_transformer<T>(input, output, dimC, op);
}
template <typename T>
FORCEINLINE static void rgb_hsv(const NDArray* input, NDArray* output, const int dimC) {
auto op = nd4j::ops::helpers::rgbToHsv<T>;
return triple_transformer<T>(input, output, dimC, op);
}
void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), hsv_rgb, (input, output, dimC), FLOAT_TYPES);
}
void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), rgb_hsv, (input, output, dimC), FLOAT_TYPES);
}
}
}
}

View File

@ -0,0 +1,139 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#include <ops/declarable/helpers/color_models_conv.h>
#include <ops/declarable/helpers/adjust_hue.h>
#include <ops/declarable/helpers/adjust_saturation.h>
#include <helpers/ConstantTadHelper.h>
#include <PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static void _CUDA_G rgbToHsvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const int dimC) {
const T* x = reinterpret_cast<const T*>(vx);
T* z = reinterpret_cast<T*>(vz);
__shared__ int rank;
__shared__ Nd4jLong xDimCstride, zDimCstride;
if (threadIdx.x == 0) {
rank = shape::rank(xShapeInfo);
xDimCstride = shape::stride(xShapeInfo)[dimC];
zDimCstride = shape::stride(zShapeInfo)[dimC];
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
const T* xTad = x + xTadOffsets[i];
T* zTad = z + zTadOffsets[i];
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
template <typename T>
static void _CUDA_G hsvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const int dimC) {
const T* x = reinterpret_cast<const T*>(vx);
T* z = reinterpret_cast<T*>(vz);
__shared__ int rank;
__shared__ Nd4jLong xDimCstride, zDimCstride;
if (threadIdx.x == 0) {
rank = shape::rank(xShapeInfo);
xDimCstride = shape::stride(xShapeInfo)[dimC];
zDimCstride = shape::stride(zShapeInfo)[dimC];
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
const T* xTad = x + xTadOffsets[i];
T* zTad = z + zTadOffsets[i];
hsvToRgb<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static _CUDA_H void hsvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const int dimC) {
hsvToRgbCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
}
template<typename T>
static _CUDA_H void rgbToHsvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const int dimC) {
rgbToHsvCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
}
void transform_hsv_rgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "hsv_to_rgb");
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), hsvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES);
NDArray::registerSpecialUse({output}, {input});
manager.synchronize();
}
void transform_rgb_hsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "rgb_to_hsv");
NDArray::prepareSpecialUse({output}, {input});
BUILD_SINGLE_SELECTOR(input->dataType(), rgbToHsvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES);
NDArray::registerSpecialUse({output}, {input});
manager.synchronize();
}
}
}
}

View File

@ -15,9 +15,9 @@
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// //
#include "testlayers.h" #include "testlayers.h"
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
@ -40,13 +40,13 @@ public:
}; };
TEST_F(DeclarableOpsTests16, scatter_upd_1) { TEST_F(DeclarableOpsTests16, scatter_upd_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1.f, 1.f, 1.f}); auto x = NDArrayFactory::create<float>('c', { 3 }, { 1.f, 1.f, 1.f });
auto y = NDArrayFactory::create<int>(0); auto y = NDArrayFactory::create<int>(0);
auto w = NDArrayFactory::create<float>(3.0f); auto w = NDArrayFactory::create<float>(3.0f);
auto e = NDArrayFactory::create<float>('c', {3}, {3.f, 1.f, 1.f}); auto e = NDArrayFactory::create<float>('c', { 3 }, { 3.f, 1.f, 1.f });
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &y, &w}, {}, {}); auto result = op.execute({ &x, &y, &w }, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -58,15 +58,15 @@ TEST_F(DeclarableOpsTests16, scatter_upd_1) {
TEST_F(DeclarableOpsTests16, scatter_upd_2) { TEST_F(DeclarableOpsTests16, scatter_upd_2) {
NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32);
NDArray indices('c', {2}, {2,5}, nd4j::DataType::INT32); NDArray indices('c', { 2 }, { 2,5 }, nd4j::DataType::INT32);
NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32);
NDArray e('c', {10, 3}, {1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30}, nd4j::DataType::FLOAT32); NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, nd4j::DataType::FLOAT32);
x.linspace(1); x.linspace(1);
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &indices, &updates}, {}, {}); auto result = op.execute({ &x, &indices, &updates }, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -78,22 +78,22 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
TEST_F(DeclarableOpsTests16, scatter_upd_3) { TEST_F(DeclarableOpsTests16, scatter_upd_3) {
NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); NDArray x('c', { 10, 3 }, nd4j::DataType::FLOAT32);
NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32); NDArray indices('c', { 2 }, { 20,5 }, nd4j::DataType::INT32);
NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, nd4j::DataType::FLOAT32);
NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32); NDArray output('c', { 10, 3 }, nd4j::DataType::FLOAT32);
nd4j::ops::scatter_upd op; nd4j::ops::scatter_upd op;
ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true }));
} }
TEST_F(DeclarableOpsTests16, test_size_dtype_1) { TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1}); auto x = NDArrayFactory::create<float>('c', { 3 }, { 1, 1, 1 });
auto z = NDArrayFactory::create<float>(0.0f); auto z = NDArrayFactory::create<float>(0.0f);
auto e = NDArrayFactory::create<float>(3.0f); auto e = NDArrayFactory::create<float>(3.0f);
nd4j::ops::size op; nd4j::ops::size op;
auto status = op.execute({&x}, {&z}, {}, {}, {}); auto status = op.execute({ &x }, { &z }, {}, {}, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -103,7 +103,7 @@ TEST_F(DeclarableOpsTests16, test_empty_noop_1) {
auto z = NDArrayFactory::empty<Nd4jLong>(); auto z = NDArrayFactory::empty<Nd4jLong>();
nd4j::ops::noop op; nd4j::ops::noop op;
auto status = op.execute({}, {&z}, {}, {}, {}); auto status = op.execute({}, { &z }, {}, {}, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -120,22 +120,22 @@ TEST_F(DeclarableOpsTests16, test_empty_noop_2) {
} }
TEST_F(DeclarableOpsTests16, test_svd_1) { TEST_F(DeclarableOpsTests16, test_svd_1) {
auto x = NDArrayFactory::create<float>('c', {3, 3}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f}); auto x = NDArrayFactory::create<float>('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f });
auto z = NDArrayFactory::create<float>('c', {3}); auto z = NDArrayFactory::create<float>('c', { 3 });
nd4j::ops::svd op; nd4j::ops::svd op;
auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
auto x = NDArrayFactory::create<Nd4jLong>({37, 37, 37}); auto x = NDArrayFactory::create<Nd4jLong>({ 37, 37, 37 });
auto y = NDArrayFactory::create<Nd4jLong>({8723, 8723, 8723}); auto y = NDArrayFactory::create<Nd4jLong>({ 8723, 8723, 8723 });
auto e = NDArrayFactory::create<Nd4jLong>(18); auto e = NDArrayFactory::create<Nd4jLong>(18);
nd4j::ops::bits_hamming_distance op; nd4j::ops::bits_hamming_distance op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.execute({ &x, &y }, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
@ -146,9 +146,9 @@ TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
} }
TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
auto input = NDArrayFactory::create<float>('c', {512}); auto input = NDArrayFactory::create<float>('c', { 512 });
auto low = NDArrayFactory::create<float>('c', {512}); auto low = NDArrayFactory::create<float>('c', { 512 });
auto high = NDArrayFactory::create<float>('c', {512}); auto high = NDArrayFactory::create<float>('c', { 512 });
auto output = NDArrayFactory::create<float>(0.0f); auto output = NDArrayFactory::create<float>(0.0f);
@ -157,16 +157,16 @@ TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) {
high.linspace(1.0); high.linspace(1.0);
nd4j::ops::knn_mindistance op; nd4j::ops::knn_mindistance op;
auto result = op.execute({&input, &low, &high}, {&output}, {}, {}, {}); auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {});
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
} }
TEST_F(DeclarableOpsTests16, test_empty_cast_1) { TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
auto x = NDArrayFactory::create<bool>('c', {1, 0, 2}); auto x = NDArrayFactory::create<bool>('c', { 1, 0, 2 });
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 0, 2}); auto e = NDArrayFactory::create<Nd4jLong>('c', { 1, 0, 2 });
nd4j::ops::cast op; nd4j::ops::cast op;
auto result = op.execute({&x}, {}, {10}); auto result = op.execute({ &x }, {}, { 10 });
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0)); ASSERT_EQ(e, *result->at(0));
@ -175,10 +175,10 @@ TEST_F(DeclarableOpsTests16, test_empty_cast_1) {
TEST_F(DeclarableOpsTests16, test_range_1) { TEST_F(DeclarableOpsTests16, test_range_1) {
nd4j::ops::range op; nd4j::ops::range op;
auto z = NDArrayFactory::create<float>('c', {200}); auto z = NDArrayFactory::create<float>('c', { 200 });
Context ctx(1); Context ctx(1);
ctx.setTArguments({-1.0, 1.0, 0.01}); ctx.setTArguments({ -1.0, 1.0, 0.01 });
ctx.setOutputArray(0, &z); ctx.setOutputArray(0, &z);
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
@ -187,9 +187,9 @@ TEST_F(DeclarableOpsTests16, test_range_1) {
TEST_F(DeclarableOpsTests16, test_range_2) { TEST_F(DeclarableOpsTests16, test_range_2) {
nd4j::ops::range op; nd4j::ops::range op;
auto z = NDArrayFactory::create<float>('c', {200}); auto z = NDArrayFactory::create<float>('c', { 200 });
double tArgs[] = {-1.0, 1.0, 0.01}; double tArgs[] = { -1.0, 1.0, 0.01 };
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0); auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0);
shape::printShapeInfoLinear("Result", shapes->at(0)); shape::printShapeInfoLinear("Result", shapes->at(0));
@ -199,27 +199,27 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
} }
TEST_F(DeclarableOpsTests16, test_reverse_1) { TEST_F(DeclarableOpsTests16, test_reverse_1) {
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211}; std::vector<Nd4jLong> rows = { 3, 5, 7, 8, 9, 10, 119, 211 };
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635}; std::vector<Nd4jLong> columns = { 6, 5, 10, 100, 153, 171, 635 };
for (auto r : rows) { for (auto r : rows) {
for (auto c : columns) { for (auto c : columns) {
//nd4j_printf("Trying [%i, %i]\n", r, c); //nd4j_printf("Trying [%i, %i]\n", r, c);
auto array = NDArrayFactory::create<float>('c', {r, c}); auto array = NDArrayFactory::create<float>('c', { r, c });
auto exp = NDArrayFactory::create<float>('c', {r, c}); auto exp = NDArrayFactory::create<float>('c', { r, c });
auto reversed = NDArrayFactory::create<float>('c', {r, c}); auto reversed = NDArrayFactory::create<float>('c', { r, c });
auto rowOriginal = NDArrayFactory::create<float>('c', {c}); auto rowOriginal = NDArrayFactory::create<float>('c', { c });
auto rowReversed = NDArrayFactory::create<float>('c', {c}); auto rowReversed = NDArrayFactory::create<float>('c', { c });
for (int e = 0; e < c; e++) { for (int e = 0; e < c; e++) {
rowOriginal.p(e, (float) e); rowOriginal.p(e, (float)e);
rowReversed.p(c - e - 1, (float) e); rowReversed.p(c - e - 1, (float)e);
} }
auto listI = array.allTensorsAlongDimension({1}); auto listI = array.allTensorsAlongDimension({ 1 });
auto listE = exp.allTensorsAlongDimension({1}); auto listE = exp.allTensorsAlongDimension({ 1 });
for (int e = 0; e < r; e++) { for (int e = 0; e < r; e++) {
listI->at(e)->assign(rowOriginal); listI->at(e)->assign(rowOriginal);
@ -230,11 +230,466 @@ TEST_F(DeclarableOpsTests16, test_reverse_1) {
delete listE; delete listE;
nd4j::ops::reverse op; nd4j::ops::reverse op;
Nd4jLong axis = 1; Nd4jLong axis = 1;
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); auto status = op.execute({ &array }, { &reversed }, {}, { axis }, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp, reversed); ASSERT_EQ(exp, reversed);
} }
} }
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) {
/*
test case generated by python colorsys and scaled to suit our needs
from colorsys import *
from random import *
import numpy as np
rgbs = np.array([randint(0,255) for x in range(0,3*4*5)]).reshape([5,4,3])
hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0]/255,x[1]/255,x[2]/255))*np.array([360,1,1]),2,rgbs)
rgbs.ravel()
hsvs.ravel()
*/
auto rgbs = NDArrayFactory::create<float>('c', { 5, 4, 3 },
{
213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f
});
auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 },
{
6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f
});
auto actual = NDArrayFactory::create<float>('c', { 5,4,3 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) {
/*
swapped_rgbs=rgbs.swapaxes(1,2).ravel()
swapped_hsvs=hsvs.swapaxes(1,2).ravel()
*/
auto rgbs = NDArrayFactory::create<float>('c', { 5,3,4 },
{
213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f,
191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f,
168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f,
221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f,
209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f
});
auto expected = NDArrayFactory::create<float>('c', { 5,3,4 },
{
6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f,
2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f,
8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f,
2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f,
7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f,
9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f,
3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f,
5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f,
4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f,
2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f,
7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f,
9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f,
1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f,
9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f,
8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f
});
auto actual = NDArrayFactory::create<float>('c', { 5,3,4 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 1 });
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) {
/*
2D
*/
auto rgbs = NDArrayFactory::create<float>('c', { 8,3 },
{ 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f });
auto expected = NDArrayFactory::create<float>('c', { 8,3 },
{ 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
77.6f, 0.49019608f, 0.6f, 260.74468085f,
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 8,3 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) {
/*
2D
*/
auto rgbs = NDArrayFactory::create<float>('c', { 3,8 },
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f,
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f });
auto expected = NDArrayFactory::create<float>('c', { 3, 8 },
{ 263.25842697f, 279.86842105f, 71.30044843f, 180.f,
77.6f, 260.74468085f, 296.12903226f, 289.82142857f,
0.74476987f, 0.9047619f, 1.f, 0.74871795f,
0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f,
0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f,
0.6f, 0.81960784f, 0.41960784f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 3, 8 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) {
/*
*/
auto rgbs = NDArrayFactory::create<float>('c', { 3 },
{ 213.f, 220.f, 164.f });
auto expected = NDArrayFactory::create<float>('c', { 3 },
{ 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f });
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
/*
*/
auto rgbs = NDArrayFactory::create<float>('c', { 3,8 },
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f,
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f });
auto expected = NDArrayFactory::create<float>('c', { 3 },
{ 263.25842697f, 0.74476987f, 0.9372549f });
//get subarray
std::unique_ptr<NDArray> subArrRgbs(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }));
subArrRgbs->reshapei({ 3 });
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrRgbs->printShapeInfo("subArrRgbs");
#endif
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, subArrRgbs.get());
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx);
#if 0
//visual check
subArrRgbs->printBuffer("subArrRgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) {
/*
using the same numbers of rgb_to_hsv_1 test
*/
auto expected = NDArrayFactory::create<float>('c', { 5,4,3 },
{ 213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f }
);
auto hsvs = NDArrayFactory::create<float>('c', { 5,4,3 },
{
6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f
});
auto actual = NDArrayFactory::create<float>('c', { 5,4,3 });
Context ctx(1);
ctx.setInputArray(0, &hsvs);
ctx.setOutputArray(0, &actual);
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) {
/*
using the same numbers of hsv_to_rgb_2
*/
auto expected = NDArrayFactory::create<float>('c', { 5,3,4 },
{ 213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f,
191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f,
168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f,
221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f,
209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f }
);
auto hsvs = NDArrayFactory::create<float>('c', { 5,3,4 },
{
6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f,
2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f,
8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f,
2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f,
7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f,
9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f,
3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f,
5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f,
4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f,
2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f,
7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f,
9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f,
1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f,
9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f,
8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f
});
auto actual = NDArrayFactory::create<float>('c', { 5,3,4 });
Context ctx(1);
ctx.setInputArray(0, &hsvs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 1 });
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) {
/*
2D
*/
auto expected = NDArrayFactory::create<float>('c', { 8,3 },
{ 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f });
auto hsvs = NDArrayFactory::create<float>('c', { 8,3 },
{ 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
77.6f, 0.49019608f, 0.6f, 260.74468085f,
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 8,3 });
Context ctx(1);
ctx.setInputArray(0, &hsvs);
ctx.setOutputArray(0, &actual);
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) {
/*
2D
*/
auto expected = NDArrayFactory::create<float>('c', { 3,8 },
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f,
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f });
auto hsvs = NDArrayFactory::create<float>('c', { 3,8 },
{ 263.25842697f, 279.86842105f, 71.30044843f, 180.f,
77.6f, 260.74468085f, 296.12903226f, 289.82142857f,
0.74476987f, 0.9047619f, 1.f, 0.74871795f,
0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f,
0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f,
0.6f, 0.81960784f, 0.41960784f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 3,8 });
Context ctx(1);
ctx.setInputArray(0, &hsvs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) {
/*
*/
auto expected = NDArrayFactory::create<float>('c', { 3 },
{ 213.f, 220.f, 164.f });
auto hsvs = NDArrayFactory::create<float>('c', { 3 },
{ 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f });
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &hsvs);
ctx.setOutputArray(0, &actual);
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
auto expected = NDArrayFactory::create<double>('c', { 3 },
{ 130.0, 61.0, 239.0 });
auto hsvs = NDArrayFactory::create<double>('c', { 3,8 },
{ 263.25842697, 279.86842105, 71.30044843, 180,
77.6, 260.74468085, 296.12903226, 289.82142857,
0.74476987, 0.9047619, 1., 0.74871795,
0.49019608, 0.89952153, 0.86915888, 0.53333333,
0.9372549, 0.65882353, 0.8745098, 0.76470588,
0.6, 0.81960784, 0.41960784, 0.82352941
});
//get subarray
std::unique_ptr<NDArray> subArrHsvs(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }));
subArrHsvs->reshapei({ 3 });
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrHsvs->printShapeInfo("subArrHsvs");
#endif
auto actual = NDArrayFactory::create<double>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, subArrHsvs.get());
ctx.setOutputArray(0, &actual);
nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx);
#if 0
//visual check
subArrHsvs->printBuffer("subArrHsvs ");
actual.printBuffer("rgb ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}