RgbToYiq and YiqToRgb operations (#142)
* RgbToYiq and YiqToRgb Signed-off-by: Abdelrauf <rauf@konduit.ai> * CUDA impl for RgbToYiq and YiqToRgb Signed-off-by: raver119 <raver119@gmail.com> * remove print Signed-off-by: raver119 <raver119@gmail.com> * allow inplace for hsv,rgb,yiq ops Signed-off-by: Abdelrauf <rauf@konduit.ai> Co-authored-by: raver119 <raver119@gmail.com>master
parent
62f93ac211
commit
39d43ca170
|
@ -15,7 +15,7 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Adel Rauf (rauf@konduit.ai)
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#include <ops/declarable/headers/images.h>
|
||||
|
@ -26,7 +26,7 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) {
|
||||
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Adel Rauf (rauf@konduit.ai)
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
|
||||
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
@ -44,7 +44,7 @@ CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
|
|||
if (argSize > 0) {
|
||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||
}
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC));
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
helpers::transformRgbHsv(block.launchContext(), input, output, dimC);
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#include <ops/declarable/headers/images.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
|
||||
|
||||
CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
if (input->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
const int rank = input->rankOf();
|
||||
const int arg_size = block.getIArguments()->size();
|
||||
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
|
||||
REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||
if (arg_size > 0) {
|
||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||
}
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
helpers::transformRgbYiq(block.launchContext(), input, output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
DECLARE_TYPES(rgb_to_yiq) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||
->setSameMode(true);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#include <ops/declarable/headers/images.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
|
||||
|
||||
CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
if (input->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
const int rank = input->rankOf();
|
||||
const int arg_size = block.getIArguments()->size();
|
||||
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
|
||||
REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||
if (arg_size > 0) {
|
||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||
}
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
helpers::transformYiqRgb(block.launchContext(), input, output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
DECLARE_TYPES(yiq_to_rgb) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -18,7 +18,7 @@
|
|||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
//
|
||||
// @author Adel Rauf (rauf@konduit.ai)
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_HEADERS_IMAGES_H
|
||||
|
@ -65,6 +65,28 @@ namespace ops {
|
|||
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Rgb To Yiq
|
||||
* Input arrays:
|
||||
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
* Int arguments:
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_rgb_to_yiq)
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Yiq To Rgb
|
||||
* Input arrays:
|
||||
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
* Int arguments:
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_yiq_to_rgb)
|
||||
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
// @author Adel Rauf (rauf@konduit.ai)
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/adjust_hue.h>
|
||||
|
@ -111,6 +111,64 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
|||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) {
|
||||
|
||||
const int rank = input->rankOf();
|
||||
|
||||
const T* x = input->bufferAsT<T>();
|
||||
T* z = output->bufferAsT<T>();
|
||||
// TODO: Use tensordot or other optimizied helpers to see if we can get better performance.
|
||||
|
||||
if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
//simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A'
|
||||
// v.shape (1,3) row vector
|
||||
T x0, x1, x2;
|
||||
x0 = x[i]; //just additional hint
|
||||
x1 = x[i + 1];
|
||||
x2 = x[i + 2];
|
||||
z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
|
||||
z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
|
||||
z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||
}
|
||||
else {
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
const T* xTad = x + packX.platformOffsets()[i];
|
||||
T* zTad = z + packZ.platformOffsets()[i];
|
||||
//simple M*v //tr.T*v
|
||||
T x0, x1, x2;
|
||||
x0 = xTad[0];
|
||||
x1 = xTad[xDimCstride];
|
||||
x2 = xTad[2 * xDimCstride];
|
||||
zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
|
||||
zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
|
||||
zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) {
|
||||
|
@ -124,6 +182,31 @@ FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int
|
|||
return tripleTransformer<T>(input, output, dimC, op);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) {
|
||||
T arr[3][3] = {
|
||||
{ (T)0.299, (T)0.59590059, (T)0.2115 },
|
||||
{ (T)0.587, (T)-0.27455667, (T)-0.52273617 },
|
||||
{ (T)0.114, (T)-0.32134392, (T)0.31119955 }
|
||||
};
|
||||
return tripleTransformer<T>(input, output, dimC, arr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) {
|
||||
//TODO: this operation does not use the clamp operation, so there is a possibility being out of range.
|
||||
//Justify that it will not be out of range for images data
|
||||
T arr[3][3] = {
|
||||
{ (T)1, (T)1, (T)1 },
|
||||
{ (T)0.95598634, (T)-0.27201283, (T)-1.10674021 },
|
||||
{ (T)0.6208248, (T)-0.64720424, (T)1.70423049 }
|
||||
};
|
||||
return tripleTransformer<T>(input, output, dimC, arr);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
@ -132,6 +215,15 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
|
|||
BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -211,12 +211,97 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
|
|||
manager.synchronize();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) {
|
||||
const auto x = reinterpret_cast<const T*>(vx);
|
||||
auto z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ Nd4jLong zLen, *sharedMem;
|
||||
__shared__ int rank; // xRank == zRank
|
||||
|
||||
float yiqarr[3][3] = {
|
||||
{ 0.299f, 0.59590059f, 0.2115f },
|
||||
{ 0.587f, -0.27455667f, -0.52273617f },
|
||||
{ 0.114f, -0.32134392f, 0.31119955f }
|
||||
};
|
||||
|
||||
float rgbarr[3][3] = {
|
||||
{ 1.f, 1.f, 1.f },
|
||||
{ 0.95598634f, -0.27201283f, -1.10674021f },
|
||||
{ 0.6208248f, -0.64720424f, 1.70423049f }
|
||||
};
|
||||
|
||||
auto tr = mode == 1? yiqarr : rgbarr;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
rank = shape::rank(zShapeInfo);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
Nd4jLong* coords = sharedMem + threadIdx.x * rank;
|
||||
|
||||
if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) {
|
||||
for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) {
|
||||
auto i = f * 3;
|
||||
|
||||
auto xi0 = x[i];
|
||||
auto xi1 = x[i+1];
|
||||
auto xi2 = x[i+2];
|
||||
|
||||
for (int e = 0; e < 3; e++)
|
||||
z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
|
||||
}
|
||||
} else {
|
||||
// TAD based case
|
||||
const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||
const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||
|
||||
for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) {
|
||||
const T* xTad = x + xOffsets[i];
|
||||
T* zTad = z + zOffsets[i];
|
||||
|
||||
auto xi0 = xTad[0];
|
||||
auto xi1 = xTad[xDimCstride];
|
||||
auto xi2 = xTad[xDimCstride * 2];
|
||||
|
||||
for (int e = 0; e < 3; e++)
|
||||
zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void rgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads());
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void yiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads());
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
//
|
||||
// @author Adel Rauf (rauf@konduit.ai)
|
||||
// @author AbdelRauf (rauf@konduit.ai)
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_HELPERS_IMAGES_H
|
||||
|
@ -33,9 +33,14 @@ namespace ops {
|
|||
namespace helpers {
|
||||
|
||||
void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
|
||||
|
||||
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
|
||||
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
|
||||
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
|
||||
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
|
|||
expected.reshapei({ 3 });
|
||||
#if 0
|
||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||
subArrRgbs->printShapeInfo("subArrRgbs");
|
||||
subArrRgbs.printShapeInfo("subArrRgbs");
|
||||
#endif
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
|
||||
|
@ -636,13 +636,13 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
|
|||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
//get subarray
|
||||
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
subArrHsvs.reshapei({ 3 });
|
||||
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
expected.reshapei({ 3 });
|
||||
#if 0
|
||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||
subArrHsvs->printShapeInfo("subArrHsvs");
|
||||
subArrHsvs.printShapeInfo("subArrHsvs");
|
||||
#endif
|
||||
|
||||
Context ctx(1);
|
||||
|
@ -655,3 +655,446 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
|
|||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
|
||||
/**
|
||||
generated using numpy
|
||||
_rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f],
|
||||
[0.587f, -0.27455667f, -0.52273617f],
|
||||
[0.114f, -0.32134392f, 0.31119955f]])
|
||||
nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3])
|
||||
out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]])
|
||||
|
||||
#alternatively you could use just with apply
|
||||
out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs)
|
||||
|
||||
*/
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
|
||||
{
|
||||
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
|
||||
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
|
||||
0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
|
||||
0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
|
||||
0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
|
||||
0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
|
||||
0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
|
||||
0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
|
||||
0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
|
||||
0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
|
||||
0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
|
||||
0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
|
||||
});
|
||||
|
||||
auto expected = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
|
||||
{
|
||||
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
|
||||
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
|
||||
-0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
|
||||
0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
|
||||
0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
|
||||
0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
|
||||
0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
|
||||
-0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
|
||||
0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
|
||||
0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
|
||||
-0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
|
||||
-0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
|
||||
});
|
||||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 5, 4, 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &rgb);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 5, 3, 4 },
|
||||
{
|
||||
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
|
||||
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
|
||||
0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f,
|
||||
0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f ,
|
||||
0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f,
|
||||
0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f ,
|
||||
0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f ,
|
||||
0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f,
|
||||
0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f ,
|
||||
0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f,
|
||||
0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f ,
|
||||
0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f
|
||||
});
|
||||
|
||||
auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 },
|
||||
{
|
||||
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
|
||||
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
|
||||
-0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f,
|
||||
0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f,
|
||||
0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f,
|
||||
0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f,
|
||||
-0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f ,
|
||||
-0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f,
|
||||
-0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f,
|
||||
0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f,
|
||||
0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f,
|
||||
-0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f
|
||||
});
|
||||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 5, 3, 4 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &rgb);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
ctx.setIArguments({ 1 });
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 4, 3 },
|
||||
{
|
||||
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
|
||||
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
|
||||
0.98633456f, 0.00158441f
|
||||
});
|
||||
|
||||
auto expected = NDArrayFactory::create<float>('c', { 4, 3 },
|
||||
{
|
||||
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
|
||||
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
|
||||
-0.07432612f, -0.44518381f
|
||||
});
|
||||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 4, 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &rgb);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {
|
||||
|
||||
auto rgb = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||
{
|
||||
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
|
||||
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
|
||||
0.48942474f, 0.00158441f
|
||||
});
|
||||
|
||||
auto expected = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||
{
|
||||
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
|
||||
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
|
||||
-0.04447775f, -0.44518381f
|
||||
});
|
||||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &rgb);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
ctx.setIArguments({ 0 });
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {
|
||||
|
||||
auto rgbs = NDArrayFactory::create<float>('c', { 3 },
|
||||
{ 0.48055f , 0.80757356f, 0.2564435f });
|
||||
auto expected = NDArrayFactory::create<float>('c', { 3 },
|
||||
{ 0.64696468f, -0.01777124f, -0.24070648f, });
|
||||
|
||||
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &rgbs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
|
||||
|
||||
auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||
{
|
||||
0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
|
||||
0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
|
||||
0.48942474f, 0.00158441f
|
||||
});
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 },
|
||||
{
|
||||
0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
|
||||
0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
|
||||
-0.04447775f, -0.44518381f
|
||||
});
|
||||
|
||||
//get subarray
|
||||
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
subArrRgbs.reshapei({ 3 });
|
||||
expected.reshapei({ 3 });
|
||||
#if 0
|
||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||
subArrRgbs.printShapeInfo("subArrRgbs");
|
||||
#endif
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &subArrRgbs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
nd4j::ops::rgb_to_yiq op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
|
||||
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
|
||||
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
|
||||
-0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f,
|
||||
0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f,
|
||||
0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f,
|
||||
0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f,
|
||||
0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f,
|
||||
-0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f,
|
||||
-0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f,
|
||||
0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f,
|
||||
0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f,
|
||||
0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f
|
||||
});
|
||||
auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
|
||||
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
|
||||
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
|
||||
0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f,
|
||||
-0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f,
|
||||
0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f,
|
||||
0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f,
|
||||
0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f,
|
||||
0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f,
|
||||
1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f,
|
||||
0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f,
|
||||
1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f,
|
||||
-0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f
|
||||
});
|
||||
auto actual = NDArrayFactory::create<float>('c', { 5, 4, 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &yiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
|
||||
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
|
||||
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
|
||||
0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f,
|
||||
0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f,
|
||||
-0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f,
|
||||
0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f,
|
||||
0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f,
|
||||
-0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f,
|
||||
-0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f,
|
||||
-0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f,
|
||||
0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f,
|
||||
-0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f
|
||||
});
|
||||
auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
|
||||
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
|
||||
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
|
||||
0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f,
|
||||
0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f,
|
||||
-0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f,
|
||||
0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f,
|
||||
0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f,
|
||||
0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f,
|
||||
0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f,
|
||||
1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f,
|
||||
0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f,
|
||||
0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f
|
||||
});
|
||||
auto actual = NDArrayFactory::create<float>('c', { 5, 3, 4 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &yiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
ctx.setIArguments({ 1 });
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 4, 3 }, {
|
||||
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
|
||||
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
|
||||
-0.471601307f, 0.263960421f
|
||||
});
|
||||
auto expected = NDArrayFactory::create<float>('c', { 4, 3 }, {
|
||||
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
|
||||
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
|
||||
0.905021825f, 1.91936605f
|
||||
});
|
||||
auto actual = NDArrayFactory::create<float>('c', { 4, 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &yiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
|
||||
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
|
||||
0.145902053f, 0.263960421f
|
||||
});
|
||||
auto expected = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
|
||||
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
|
||||
0.280231822f, 1.91936605f
|
||||
});
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &yiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
ctx.setIArguments({ 0 });
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3 }, {
|
||||
0.775258899f, -0.288912386f, -0.132725924f
|
||||
});
|
||||
auto expected = NDArrayFactory::create<float>('c', { 3 }, {
|
||||
0.416663059f, 0.939747555f, 0.868814286f
|
||||
});
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &yiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
#if 0
|
||||
actual.printBuffer("actual");
|
||||
expected.printBuffer("expected");
|
||||
#endif
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
||||
|
||||
auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||
0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f,
|
||||
-0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f,
|
||||
0.145902053f, 0.263960421f
|
||||
});
|
||||
auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
|
||||
0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f,
|
||||
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f,
|
||||
0.280231822f, 1.91936605f
|
||||
});
|
||||
|
||||
//get subarray
|
||||
NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||
subArrYiqs.reshapei({ 3 });
|
||||
expected.reshapei({ 3 });
|
||||
#if 0
|
||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||
subArrYiqs.printShapeInfo("subArrYiqs");
|
||||
#endif
|
||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &subArrYiqs);
|
||||
ctx.setOutputArray(0, &actual);
|
||||
nd4j::ops::yiq_to_rgb op;
|
||||
auto status = op.execute(&ctx);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue