RgbToYiq and YiqToRgb operations (#142)

* RgbToYiq and YiqToRgb

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* CUDA impl for RgbToYiq and YiqToRgb

Signed-off-by: raver119 <raver119@gmail.com>

* remove print

Signed-off-by: raver119 <raver119@gmail.com>

* allow inplace for hsv,rgb,yiq ops

Signed-off-by: Abdelrauf <rauf@konduit.ai>

Co-authored-by: raver119 <raver119@gmail.com>
master
Abdelrauf 2019-12-24 16:20:35 +04:00 committed by raver119
parent 62f93ac211
commit 39d43ca170
9 changed files with 779 additions and 11 deletions

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
// //
// @author Adel Rauf (rauf@konduit.ai) // @author AbdelRauf (rauf@konduit.ai)
// //
#include <ops/declarable/headers/images.h> #include <ops/declarable/headers/images.h>
@ -26,7 +26,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) { CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
// //
// @author Adel Rauf (rauf@konduit.ai) // @author AbdelRauf (rauf@konduit.ai)
// //
@ -28,7 +28,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) { CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
@ -44,7 +44,7 @@ CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
if (argSize > 0) { if (argSize > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank); REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
} }
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC)); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
helpers::transformRgbHsv(block.launchContext(), input, output, dimC); helpers::transformRgbHsv(block.launchContext(), input, output, dimC);

View File

@ -0,0 +1,60 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf (rauf@konduit.ai)
//
#include <ops/declarable/headers/images.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (input->isEmpty())
return Status::OK();
const int rank = input->rankOf();
const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
helpers::transformRgbYiq(block.launchContext(), input, output, dimC);
return Status::OK();
}
DECLARE_TYPES(rgb_to_yiq) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

View File

@ -0,0 +1,61 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf (rauf@konduit.ai)
//
#include <ops/declarable/headers/images.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (input->isEmpty())
return Status::OK();
const int rank = input->rankOf();
const int arg_size = block.getIArguments()->size();
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);
if (arg_size > 0) {
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
}
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC));
helpers::transformYiqRgb(block.launchContext(), input, output, dimC);
return Status::OK();
}
DECLARE_TYPES(yiq_to_rgb) {
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
->setSameMode(true);
}
}
}

View File

@ -18,7 +18,7 @@
// @author Oleh Semeniv (oleg.semeniv@gmail.com) // @author Oleh Semeniv (oleg.semeniv@gmail.com)
// //
// //
// @author Adel Rauf (rauf@konduit.ai) // @author AbdelRauf (rauf@konduit.ai)
// //
#ifndef LIBND4J_HEADERS_IMAGES_H #ifndef LIBND4J_HEADERS_IMAGES_H
@ -65,6 +65,28 @@ namespace ops {
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0); DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
#endif #endif
/**
* Rgb To Yiq
* Input arrays:
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_rgb_to_yiq)
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, false, 0, 0);
#endif
/**
* Yiq To Rgb
* Input arrays:
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
*/
#if NOT_EXCLUDED(OP_yiq_to_rgb)
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, false, 0, 0);
#endif
} }
} }

View File

@ -16,7 +16,7 @@
// //
// @author Oleh Semeniv (oleg.semeniv@gmail.com) // @author Oleh Semeniv (oleg.semeniv@gmail.com)
// @author Adel Rauf (rauf@konduit.ai) // @author AbdelRauf (rauf@konduit.ai)
// //
#include <ops/declarable/helpers/adjust_hue.h> #include <ops/declarable/helpers/adjust_hue.h>
@ -111,6 +111,64 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
} }
template <typename T>
FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) {
const int rank = input->rankOf();
const T* x = input->bufferAsT<T>();
T* z = output->bufferAsT<T>();
// TODO: Use tensordot or other optimizied helpers to see if we can get better performance.
if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
//simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A'
// v.shape (1,3) row vector
T x0, x1, x2;
x0 = x[i]; //just additional hint
x1 = x[i + 1];
x2 = x[i + 2];
z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
}
};
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
}
else {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
const Nd4jLong numOfTads = packX.numberOfTads();
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
const T* xTad = x + packX.platformOffsets()[i];
T* zTad = z + packZ.platformOffsets()[i];
//simple M*v //tr.T*v
T x0, x1, x2;
x0 = xTad[0];
x1 = xTad[xDimCstride];
x2 = xTad[2 * xDimCstride];
zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
}
};
samediff::Threads::parallel_tad(func, 0, numOfTads);
}
}
template <typename T> template <typename T>
FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) { FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) {
@ -124,6 +182,31 @@ FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int
return tripleTransformer<T>(input, output, dimC, op); return tripleTransformer<T>(input, output, dimC, op);
} }
template <typename T>
FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) {
T arr[3][3] = {
{ (T)0.299, (T)0.59590059, (T)0.2115 },
{ (T)0.587, (T)-0.27455667, (T)-0.52273617 },
{ (T)0.114, (T)-0.32134392, (T)0.31119955 }
};
return tripleTransformer<T>(input, output, dimC, arr);
}
template <typename T>
FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) {
//TODO: this operation does not use the clamp operation, so there is a possibility being out of range.
//Justify that it will not be out of range for images data
T arr[3][3] = {
{ (T)1, (T)1, (T)1 },
{ (T)0.95598634, (T)-0.27201283, (T)-1.10674021 },
{ (T)0.6208248, (T)-0.64720424, (T)1.70423049 }
};
return tripleTransformer<T>(input, output, dimC, arr);
}
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) { void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES);
} }
@ -132,6 +215,15 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES);
} }
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES);
}
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES);
}
} }
} }
} }

View File

@ -211,12 +211,97 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
manager.synchronize(); manager.synchronize();
} }
template<typename T>
__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) {
const auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, *sharedMem;
__shared__ int rank; // xRank == zRank
float yiqarr[3][3] = {
{ 0.299f, 0.59590059f, 0.2115f },
{ 0.587f, -0.27455667f, -0.52273617f },
{ 0.114f, -0.32134392f, 0.31119955f }
};
float rgbarr[3][3] = {
{ 1.f, 1.f, 1.f },
{ 0.95598634f, -0.27201283f, -1.10674021f },
{ 0.6208248f, -0.64720424f, 1.70423049f }
};
auto tr = mode == 1? yiqarr : rgbarr;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
}
__syncthreads();
Nd4jLong* coords = sharedMem + threadIdx.x * rank;
if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) {
for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) {
auto i = f * 3;
auto xi0 = x[i];
auto xi1 = x[i+1];
auto xi2 = x[i+2];
for (int e = 0; e < 3; e++)
z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
}
} else {
// TAD based case
const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC];
const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC];
for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) {
const T* xTad = x + xOffsets[i];
T* zTad = z + zOffsets[i];
auto xi0 = xTad[0];
auto xi1 = xTad[xDimCstride];
auto xi2 = xTad[xDimCstride * 2];
for (int e = 0; e < 3; e++)
zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
}
}
}
template <typename T>
static void rgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
NDArray::prepareSpecialUse({output}, {input});
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads());
NDArray::registerSpecialUse({output}, {input});
}
template <typename T>
FORCEINLINE static void yiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
NDArray::prepareSpecialUse({output}, {input});
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads());
NDArray::registerSpecialUse({output}, {input});
}
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES);
}
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES);
}

View File

@ -18,7 +18,7 @@
// @author Oleh Semeniv (oleg.semeniv@gmail.com) // @author Oleh Semeniv (oleg.semeniv@gmail.com)
// //
// //
// @author Adel Rauf (rauf@konduit.ai) // @author AbdelRauf (rauf@konduit.ai)
// //
#ifndef LIBND4J_HELPERS_IMAGES_H #ifndef LIBND4J_HELPERS_IMAGES_H
@ -33,9 +33,14 @@ namespace ops {
namespace helpers { namespace helpers {
void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC); void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC); void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
} }
} }
} }

View File

@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
expected.reshapei({ 3 }); expected.reshapei({ 3 });
#if 0 #if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrRgbs->printShapeInfo("subArrRgbs"); subArrRgbs.printShapeInfo("subArrRgbs");
#endif #endif
auto actual = NDArrayFactory::create<float>('c', { 3 }); auto actual = NDArrayFactory::create<float>('c', { 3 });
@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
expected.reshapei({ 3 }); expected.reshapei({ 3 });
#if 0 #if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrHsvs->printShapeInfo("subArrHsvs"); subArrHsvs.printShapeInfo("subArrHsvs");
#endif #endif
Context ctx(1); Context ctx(1);
@ -655,3 +655,446 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
/**
generated using numpy
_rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f],
[0.587f, -0.27455667f, -0.52273617f],
[0.114f, -0.32134392f, 0.31119955f]])
nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3])
out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]])
#alternatively you could use just with apply
out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs)
*/
auto rgb = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
{
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
});
auto expected = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
{
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
-0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
-0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
-0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
-0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
});
auto actual = NDArrayFactory::create<float>('c', { 5, 4, 3 });
Context ctx(1);
ctx.setInputArray(0, &rgb);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
{
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f,
0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f ,
0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f,
0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f ,
0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f ,
0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f,
0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f ,
0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f,
0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f ,
0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f
});
auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 },
{
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
-0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f,
0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f,
0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f,
0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f,
-0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f ,
-0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f,
-0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f,
0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f,
0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f,
-0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f
});
auto actual = NDArrayFactory::create<float>('c', { 5, 3, 4 });
Context ctx(1);
ctx.setInputArray(0, &rgb);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 1 });
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
{
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
0.98633456f, 0.00158441f
});
auto expected = NDArrayFactory::create<float>('c', { 4, 3 },
{
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
-0.07432612f, -0.44518381f
});
auto actual = NDArrayFactory::create<float>('c', { 4, 3 });
Context ctx(1);
ctx.setInputArray(0, &rgb);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
{
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
0.48942474f, 0.00158441f
});
auto expected = NDArrayFactory::create<float>('c', { 3, 4 },
{
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
-0.04447775f, -0.44518381f
});
auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
Context ctx(1);
ctx.setInputArray(0, &rgb);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
auto rgbs = NDArrayFactory::create<float>('c', { 3 },
{ 0.48055f , 0.80757356f, 0.2564435f });
auto expected = NDArrayFactory::create<float>('c', { 3 },
{ 0.64696468f, -0.01777124f, -0.24070648f, });
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &rgbs);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 },
{
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
0.48942474f, 0.00158441f
});
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 },
{
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
-0.04447775f, -0.44518381f
});
//get subarray
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrRgbs.reshapei({ 3 });
expected.reshapei({ 3 });
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrRgbs.printShapeInfo("subArrRgbs");
#endif
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &subArrRgbs);
ctx.setOutputArray(0, &actual);
nd4j::ops::rgb_to_yiq op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
-0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f,
0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f,
0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f,
0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f,
0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f,
-0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f,
-0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f,
0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f,
0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f,
0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f
});
auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f,
-0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f,
0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f,
0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f,
0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f,
0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f,
1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f,
0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f,
1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f,
-0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f
});
auto actual = NDArrayFactory::create<float>('c', { 5, 4, 3 });
Context ctx(1);
ctx.setInputArray(0, &yiqs);
ctx.setOutputArray(0, &actual);
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f,
0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f,
-0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f,
0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f,
0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f,
-0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f,
-0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f,
-0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f,
0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f,
-0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f
});
auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f,
0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f,
-0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f,
0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f,
0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f,
0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f,
0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f,
1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f,
0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f,
0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f
});
auto actual = NDArrayFactory::create<float>('c', { 5, 3, 4 });
Context ctx(1);
ctx.setInputArray(0, &yiqs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 1 });
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
-0.471601307f, 0.263960421f
});
auto expected = NDArrayFactory::create<float>('c', { 4, 3 }, {
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
0.905021825f, 1.91936605f
});
auto actual = NDArrayFactory::create<float>('c', { 4, 3 });
Context ctx(1);
ctx.setInputArray(0, &yiqs);
ctx.setOutputArray(0, &actual);
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
0.145902053f, 0.263960421f
});
auto expected = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
0.280231822f, 1.91936605f
});
auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
Context ctx(1);
ctx.setInputArray(0, &yiqs);
ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
0.775258899f, -0.288912386f, -0.132725924f
});
auto expected = NDArrayFactory::create<float>('c', { 3 }, {
0.416663059f, 0.939747555f, 0.868814286f
});
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &yiqs);
ctx.setOutputArray(0, &actual);
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
#if 0
actual.printBuffer("actual");
expected.printBuffer("expected");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
0.145902053f, 0.263960421f
});
auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
0.280231822f, 1.91936605f
});
//get subarray
NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrYiqs.reshapei({ 3 });
expected.reshapei({ 3 });
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrYiqs.printShapeInfo("subArrYiqs");
#endif
auto actual = NDArrayFactory::create<float>('c', { 3 });
Context ctx(1);
ctx.setInputArray(0, &subArrYiqs);
ctx.setOutputArray(0, &actual);
nd4j::ops::yiq_to_rgb op;
auto status = op.execute(&ctx);
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}