[WIP] Oleh rgb yuv (#147)

* libnd4j: RgbToYuv and YuvToRgb, implementations for both CPU and CUDA. Tests and review are still needed

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: RgbToYuv and YuvToRgb, replace coords method on Tad in both cpu and cuda, add tests, fixed bugs

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: RgbToYuv and YuvToRgb minor corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j: RgbToYuv and YuvToRgb corrections to use operations in-place
master
Oleh 2019-12-24 17:30:54 +02:00 committed by raver119
parent d1e5e79c10
commit 75123b0a4c
8 changed files with 559 additions and 6 deletions

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/images.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) {
    // Single input / single output; the conversion itself is delegated to the helper below.
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    const int argSize = block.getIArguments()->size();

    // Channel axis: the last dimension unless an integer argument is given;
    // a negative argument is wrapped around by adding the rank.
    int dimC = rank - 1;
    if (argSize > 0) {
        const auto axisArg = INT_ARG(0);
        dimC = static_cast<int>(axisArg >= 0 ? axisArg : axisArg + rank);
    }

    REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank);

    // The range check only matters for a user-supplied axis.
    if (argSize > 0) {
        REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
    }

    // The channel axis must hold exactly the three components R, G, B.
    REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));

    helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC);

    return Status::OK();
}
DECLARE_TYPES(rgb_to_yuv) {
    // Restrict inputs to the ALL_FLOATS type set, then enable "same mode" on the descriptor.
    // NOTE(review): setSameMode(true) presumably ties the output type to the input type - confirm.
    auto descriptor = getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS });
    descriptor->setSameMode(true);
}
}
}

View File

@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <ops/declarable/headers/images.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) {
    // Single input / single output; the conversion itself is delegated to the helper below.
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    const int argSize = block.getIArguments()->size();

    // Channel axis: the last dimension unless an integer argument is given;
    // a negative argument is wrapped around by adding the rank.
    int dimC = rank - 1;
    if (argSize > 0) {
        const auto axisArg = INT_ARG(0);
        dimC = static_cast<int>(axisArg >= 0 ? axisArg : axisArg + rank);
    }

    REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);

    // The range check only matters for a user-supplied axis.
    if (argSize > 0) {
        REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
    }

    // The channel axis must hold exactly the three components Y, U, V.
    REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", input->sizeAt(dimC));

    helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC);

    return Status::OK();
}
DECLARE_TYPES(yuv_to_rgb) {
    // Restrict inputs to the ALL_FLOATS type set, then enable "same mode" on the descriptor.
    // NOTE(review): setSameMode(true) presumably ties the output type to the input type - confirm.
    auto descriptor = getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS });
    descriptor->setSameMode(true);
}
}
}

View File

@ -42,7 +42,7 @@ namespace ops {
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_rgb_to_hsv)
DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0);
DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0);
#endif
/**
@ -53,7 +53,7 @@ namespace ops {
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_hsv_to_rgb)
DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0);
DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0);
#endif
/**
@ -65,6 +65,27 @@ namespace ops {
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
#endif
/**
* Rgb To Yuv: converts RGB triples to YUV along the channel dimension.
* Input arrays:
* 0 - input array with rank >= 1; the channel dimension must have size 3 (R, G, B).
* Int arguments:
* 0 - optional index of the channel dimension; a negative value counts from the end,
*     and the last dimension is used when the argument is omitted.
*/
#if NOT_EXCLUDED(OP_rgb_to_yuv)
DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0);
#endif
/**
* Yuv To Rgb
* Input arrays:
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_rgb_to_yuv)
DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0);
/**
* Rgb To Yiq
* Input arrays:
@ -73,7 +94,7 @@ namespace ops {
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_rgb_to_yiq)
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, false, 0, 0);
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0);
#endif
/**
@ -84,10 +105,11 @@ namespace ops {
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_yiq_to_rgb)
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, false, 0, 0);
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0);
#endif
}
}
#endif
#endif

View File

@ -17,6 +17,7 @@
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <op_boilerplate.h>
@ -106,6 +107,21 @@ FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T&
// b *= 255;
}
////////////////////////////////////////////////////////////////////////////////
/**
 * Converts one RGB triple to YUV.
 *
 * @param r,g,b  input red / green / blue components
 * @param y,u,v  output luma / chroma components
 *
 * The outputs may alias the inputs (the ops using this helper are declared
 * in-place capable), so the inputs are copied to locals before any output is written.
 */
template <typename T>
FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, T& v) {
    // Snapshot the inputs: writing y first would otherwise clobber r when y and r share storage.
    const T red = r;
    const T green = g;
    const T blue = b;

    y = static_cast<T>(0.299) * red + static_cast<T>(0.587) * green + static_cast<T>(0.114) * blue;
    u = -static_cast<T>(0.14714119) * red - static_cast<T>(0.2888691) * green + static_cast<T>(0.43601035) * blue;
    v = static_cast<T>(0.61497538) * red - static_cast<T>(0.51496512) * green - static_cast<T>(0.10001026) * blue;
}
////////////////////////////////////////////////////////////////////////////////
/**
 * Converts one YUV triple to RGB.
 *
 * @param y,u,v  input luma / chroma components
 * @param r,g,b  output red / green / blue components
 *
 * The outputs may alias the inputs (the ops using this helper are declared
 * in-place capable), so the inputs are copied to locals before any output is written.
 */
template <typename T>
FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, T& b) {
    // Snapshot the inputs: writing r first would otherwise clobber y when r and y share storage.
    const T luma = y;
    const T chromaU = u;
    const T chromaV = v;

    r = luma + static_cast<T>(1.13988303) * chromaV;
    g = luma - static_cast<T>(0.394642334) * chromaU - static_cast<T>(0.58062185) * chromaV;
    b = luma + static_cast<T>(2.03206185) * chromaU;
}
/*////////////////////////////////////////////////////////////////////////////////
template <typename T>

View File

@ -70,6 +70,65 @@ void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray
BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES);
}
/**
 * Shared CPU traversal for the RGB <-> YUV conversions.
 *
 * Calls op(in0, in1, in2, out0, out1, out2) once per channel triple taken along
 * dimension dimC of input, writing the result to the matching triple of output.
 */
template <typename T, typename Op>
FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, const int dimC, Op op) {
    const T* src = input.bufferAsT<T>();
    T* dst = output.bufferAsT<T>();

    // Fast path applies when channels are the last axis and both arrays are
    // 'c'-ordered with element-wise stride 1: triples are then adjacent in the buffers.
    const bool packedTriples = (dimC == input.rankOf() - 1 && 'c' == input.ordering() && 1 == input.ews() &&
                                'c' == output.ordering() && 1 == output.ews());

    if (!packedTriples) {
        // General path: one TAD along dimC per triple, channels separated by the dimC stride.
        auto tadsIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimC);
        auto tadsOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), dimC);

        const Nd4jLong tadCount = tadsIn.numberOfTads();
        const Nd4jLong srcStep = input.stridesOf()[dimC];
        const Nd4jLong dstStep = output.stridesOf()[dimC];

        auto perTad = PRAGMA_THREADS_FOR{
            for (auto t = start; t < stop; t += increment) {
                const T* from = src + tadsIn.platformOffsets()[t];
                T* to = dst + tadsOut.platformOffsets()[t];
                op(from[0], from[srcStep], from[2 * srcStep], to[0], to[dstStep], to[2 * dstStep]);
            }
        };

        samediff::Threads::parallel_tad(perTad, 0, tadCount);
        return;
    }

    // Packed path: walk the flat buffers three elements at a time.
    auto perTriple = PRAGMA_THREADS_FOR{
        for (auto e = start; e < stop; e += increment) {
            op(src[e], src[e + 1], src[e + 2], dst[e], dst[e + 1], dst[e + 2]);
        }
    };

    samediff::Threads::parallel_for(perTriple, 0, input.lengthOf(), 3);
}
// Typed RGB -> YUV driver: binds the per-triple rgbYuv<T> routine and runs the shared traversal.
template <typename T>
FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, const int dimC) {
    const auto pixelKernel = nd4j::ops::helpers::rgbYuv<T>;
    rgbToFromYuv_<T>(input, output, dimC, pixelKernel);
}
// CPU entry point for the RGB -> YUV conversion along dimension dimC.
// Forwards (input, output, dimC) to rgbYuv_, selected by input.dataType().
// NOTE(review): BUILD_SINGLE_SELECTOR presumably dispatches to rgbYuv_<T> for the types in FLOAT_TYPES - confirm against the macro.
// The context argument is not used in this body.
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), FLOAT_TYPES);
}
// Typed YUV -> RGB driver: binds the per-triple yuvRgb<T> routine and runs the shared traversal.
template <typename T>
FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, const int dimC) {
    const auto pixelKernel = nd4j::ops::helpers::yuvRgb<T>;
    rgbToFromYuv_<T>(input, output, dimC, pixelKernel);
}
// CPU entry point for the YUV -> RGB conversion along dimension dimC.
// Forwards (input, output, dimC) to yuvRgb_, selected by input.dataType().
// NOTE(review): BUILD_SINGLE_SELECTOR presumably dispatches to yuvRgb_<T> for the types in FLOAT_TYPES - confirm against the macro.
// The context argument is not used in this body.
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), FLOAT_TYPES);
}
template <typename T, typename Op>
FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) {

View File

@ -16,6 +16,7 @@
//
// @author Yurii Shyrma (iuriish@yahoo.com)
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
//
#include <op_boilerplate.h>
@ -29,6 +30,117 @@ namespace nd4j {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
/**
 * CUDA kernel: RGB -> YUV, one channel triple (TAD along dimC) per loop iteration.
 *
 * @param vx, xShapeInfo, xTadOffsets  input buffer, its shape info and per-TAD offsets
 * @param vz, zShapeInfo, zTadOffsets  output buffer, its shape info and per-TAD offsets
 * @param numOfTads                    number of triples to convert
 * @param dimC                         index of the channel dimension
 */
template<typename T>
__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
    const T* x = reinterpret_cast<const T*>(vx);
    T* z = reinterpret_cast<T*>(vz);

    // Channel strides are read once per block by thread 0 and shared with the other threads.
    __shared__ Nd4jLong xDimCstride, zDimCstride;
    if (threadIdx.x == 0) {
        xDimCstride = shape::stride(xShapeInfo)[dimC];
        zDimCstride = shape::stride(zShapeInfo)[dimC];
    }
    __syncthreads();

    // Index and step are computed in Nd4jLong so they match the width of numOfTads.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong i = tid; i < numOfTads; i += step) {
        const T* xTad = x + xTadOffsets[i];
        T* zTad = z + zTadOffsets[i];
        rgbYuv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
    }
}
///////////////////////////////////////////////////////////////////
// Host-side launcher: starts rgbToYuvCuda<T> on *stream with the given grid and block sizes,
// passing all buffer / shape / offset arguments through unchanged.
// NOTE(review): the launch requests 256 bytes of dynamic shared memory, while the kernel visible
// in this file declares only static __shared__ variables - confirm whether this is needed.
template<typename T>
linkage void rgbToYuvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
rgbToYuvCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
}
///////////////////////////////////////////////////////////////////
/**
 * CUDA entry point for the RGB -> YUV conversion along dimension dimC.
 *
 * Builds TAD packs along dimC for input and output, sizes the launch from the
 * number of TADs, and dispatches rgbToYuvCudaLauncher for input's data type.
 */
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC });
    auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC });

    const Nd4jLong numOfTads = packX.numberOfTads();

    // One thread per TAD, rounded up to whole blocks.
    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;

    // Label names this operation (it was copy-pasted as "yuv_to_rgb").
    PointersManager manager(context, "rgb_to_yuv");

    NDArray::prepareSpecialUse({ &output }, { &input });
    BUILD_SINGLE_SELECTOR(input.dataType(), rgbToYuvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES);
    NDArray::registerSpecialUse({ &output }, { &input });

    manager.synchronize();
}
///////////////////////////////////////////////////////////////////
/**
 * CUDA kernel: YUV -> RGB, one channel triple (TAD along dimC) per loop iteration.
 *
 * @param vx, xShapeInfo, xTadOffsets  input buffer, its shape info and per-TAD offsets
 * @param vz, zShapeInfo, zTadOffsets  output buffer, its shape info and per-TAD offsets
 * @param numOfTads                    number of triples to convert
 * @param dimC                         index of the channel dimension
 */
template<typename T>
__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
    const T* x = reinterpret_cast<const T*>(vx);
    T* z = reinterpret_cast<T*>(vz);

    // Channel strides are read once per block by thread 0 and shared with the other threads.
    __shared__ Nd4jLong xDimCstride, zDimCstride;
    if (threadIdx.x == 0) {
        xDimCstride = shape::stride(xShapeInfo)[dimC];
        zDimCstride = shape::stride(zShapeInfo)[dimC];
    }
    __syncthreads();

    // Index and step are computed in Nd4jLong so they match the width of numOfTads.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong i = tid; i < numOfTads; i += step) {
        const T* xTad = x + xTadOffsets[i];
        T* zTad = z + zTadOffsets[i];
        yuvRgb<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
    }
}
///////////////////////////////////////////////////////////////////
// Host-side launcher: starts yuvToRgbCuda<T> on *stream with the given grid and block sizes,
// passing all buffer / shape / offset arguments through unchanged.
// NOTE(review): the launch requests 256 bytes of dynamic shared memory, while the kernel visible
// in this file declares only static __shared__ variables - confirm whether this is needed.
template<typename T>
linkage void yuvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
yuvToRgbCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
}
///////////////////////////////////////////////////////////////////
/**
 * CUDA entry point for the YUV -> RGB conversion along dimension dimC.
 *
 * Builds TAD packs along dimC for input and output, sizes the launch from the
 * number of TADs, and dispatches yuvToRgbCudaLauncher for input's data type.
 */
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
    auto tadsIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC });
    auto tadsOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC });

    const Nd4jLong tadCount = tadsIn.numberOfTads();

    // One thread per TAD, rounded up to whole blocks.
    const int blockSize = MAX_NUM_THREADS / 2;
    const int gridSize = (tadCount + blockSize - 1) / blockSize;

    PointersManager manager(context, "yuv_to_rgb");

    NDArray::prepareSpecialUse({ &output }, { &input });
    BUILD_SINGLE_SELECTOR(input.dataType(), yuvToRgbCudaLauncher, (gridSize, blockSize, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), tadsIn.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), tadsOut.platformOffsets(), tadCount, dimC), FLOAT_TYPES);
    NDArray::registerSpecialUse({ &output }, { &input });

    manager.synchronize();
}
///////////////////////////////////////////////////////////////////
// for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4}
template<typename T>
@ -83,7 +195,7 @@ void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray
PointersManager manager(context, "rgbToGrs");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;

View File

@ -37,6 +37,8 @@ namespace helpers {
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
// Converts YUV triples laid out along dimension dimC of input to RGB, writing the result to output.
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
// Converts RGB triples laid out along dimension dimC of input to YUV, writing the result to output.
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);

View File

@ -1066,3 +1066,232 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) {
    // rank 1: a single RGB triple, no int argument (channel axis defaults to the last dimension)
    NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
    NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
    // rank 2: input is a permuted {3, 2} array, i.e. a {2, 3} view of the original buffer
    NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. }, nd4j::DataType::FLOAT32);
    rgbs.permutei({ 1,0 });
    NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
    // rank 2, channel axis passed explicitly as dimension 0
    NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, { 0 });
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) {
    // rank 3, channels in the last dimension (default axis)
    NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
    // rank 3, channel axis passed explicitly as dimension 1
    NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, { 1 });
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
// Negative test: with no int argument the op uses the last dimension as the channel axis,
// and here that dimension has size 4 rather than 3, so validation is expected to fail.
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
// rank 3
NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
// Either outcome is accepted: a THROW status on the result, or an exception caught below.
try {
nd4j::ops::rgb_to_yuv op;
auto result = op.execute({ &rgbs }, {}, {});
ASSERT_EQ(Status::THROW(), result->status());
delete result;
}
catch (std::exception & e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
}
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
    // rank 3, 'f'-ordered input, channels in the last dimension (default axis)
    NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
    NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
    nd4j::ops::rgb_to_yuv op;
    auto result = op.execute({ &rgbs }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
    // rank 1: a single YUV triple, 'c' order, no int argument
    NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
    // rank 1: a single YUV triple, 'f' order, no int argument
    NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
    NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) {
    // rank 2, channel axis passed explicitly as dimension 0
    NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
    NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, { 0 });
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) {
    // rank 3, channels in the last dimension (default axis)
    NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32);
    NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
    // rank 3, channel axis passed explicitly as dimension 1
    NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
    NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, { 1 });
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
// Negative test: with no int argument the op uses the last dimension as the channel axis,
// and here that dimension has size 4 rather than 3, so validation is expected to fail.
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
// rank 3
NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
// Either outcome is accepted: a THROW status on the result, or an exception caught below.
try {
nd4j::ops::yuv_to_rgb op;
auto result = op.execute({ &yuv }, {}, {});
ASSERT_EQ(Status::THROW(), result->status());
delete result;
}
catch (std::exception & e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
}
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
    // rank 3, 'f'-ordered input, channels in the last dimension (default axis)
    NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
    NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
    nd4j::ops::yuv_to_rgb op;
    auto result = op.execute({ &yuv }, {}, {});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), result->status());
    auto output = result->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete result;
}