RgbToYiq and YiqToRgb operations (#142)
* RgbToYiq and YiqToRgb Signed-off-by: Abdelrauf <rauf@konduit.ai> * CUDA impl for RgbToYiq and YiqToRgb Signed-off-by: raver119 <raver119@gmail.com> * remove print Signed-off-by: raver119 <raver119@gmail.com> * allow inplace for hsv,rgb,yiq ops Signed-off-by: Abdelrauf <rauf@konduit.ai> Co-authored-by: raver119 <raver119@gmail.com>master
parent
62f93ac211
commit
39d43ca170
|
@ -15,7 +15,7 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Adel Rauf (rauf@konduit.ai)
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/headers/images.h>
|
#include <ops/declarable/headers/images.h>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, false, 0, 0) {
|
CONFIGURABLE_OP_IMPL(hsv_to_rgb, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Adel Rauf (rauf@konduit.ai)
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
|
CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
@ -44,7 +44,7 @@ CONFIGURABLE_OP_IMPL(rgb_to_hsv, 1, 1, false, 0, 0) {
|
||||||
if (argSize > 0) {
|
if (argSize > 0) {
|
||||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||||
}
|
}
|
||||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (H, S, V), but got %i instead", input->sizeAt(dimC));
|
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoHSV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||||
|
|
||||||
helpers::transformRgbHsv(block.launchContext(), input, output, dimC);
|
helpers::transformRgbHsv(block.launchContext(), input, output, dimC);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/images.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(rgb_to_yiq, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (input->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
const int arg_size = block.getIArguments()->size();
|
||||||
|
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rank >= 1, 0, "RGBtoYIQ: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||||
|
if (arg_size > 0) {
|
||||||
|
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||||
|
}
|
||||||
|
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYIQ: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||||
|
|
||||||
|
helpers::transformRgbYiq(block.launchContext(), input, output, dimC);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_TYPES(rgb_to_yiq) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/headers/images.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGURABLE_OP_IMPL(yiq_to_rgb, 1, 1, true, 0, 0) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (input->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
const int arg_size = block.getIArguments()->size();
|
||||||
|
const int dimC = arg_size > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rank >= 1, 0, "YIQtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||||
|
if (arg_size > 0) {
|
||||||
|
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||||
|
}
|
||||||
|
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YIQtoRGB: operation expects 3 channels (Y, I, Q), but got %i instead", input->sizeAt(dimC));
|
||||||
|
|
||||||
|
helpers::transformYiqRgb(block.launchContext(), input, output, dimC);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
DECLARE_TYPES(yiq_to_rgb) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -18,7 +18,7 @@
|
||||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
//
|
//
|
||||||
//
|
//
|
||||||
// @author Adel Rauf (rauf@konduit.ai)
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
//
|
//
|
||||||
|
|
||||||
#ifndef LIBND4J_HEADERS_IMAGES_H
|
#ifndef LIBND4J_HEADERS_IMAGES_H
|
||||||
|
@ -65,6 +65,28 @@ namespace ops {
|
||||||
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
 * Rgb To Yiq
 * Input arrays:
 * 0 - input array with rank >= 1; the channel dimension must have size 3.
 * Int arguments:
 * 0 - optional argument, index of the dimension holding the 3 channels.
 *     Negative values are offset by the rank; when omitted the last dimension is used.
 *
 * The in-place flag matches the implementation, which is registered as in-place capable.
 */
#if NOT_EXCLUDED(OP_rgb_to_yiq)
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0);
#endif
|
||||||
|
|
||||||
|
/**
 * Yiq To Rgb
 * Input arrays:
 * 0 - input array with rank >= 1; the channel dimension must have size 3.
 * Int arguments:
 * 0 - optional argument, index of the dimension holding the 3 channels.
 *     Negative values are offset by the rank; when omitted the last dimension is used.
 *
 * The in-place flag matches the implementation, which is registered as in-place capable.
 */
#if NOT_EXCLUDED(OP_yiq_to_rgb)
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0);
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
// @author Adel Rauf (rauf@konduit.ai)
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/adjust_hue.h>
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
|
@ -111,6 +111,64 @@ FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC , T (&tr)[3][3] ) {
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
|
||||||
|
const T* x = input->bufferAsT<T>();
|
||||||
|
T* z = output->bufferAsT<T>();
|
||||||
|
// TODO: Use tensordot or other optimizied helpers to see if we can get better performance.
|
||||||
|
|
||||||
|
if (dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
//simple M*v //tr.T*v.T // v * tr //rule: (AB)' =B'A'
|
||||||
|
// v.shape (1,3) row vector
|
||||||
|
T x0, x1, x2;
|
||||||
|
x0 = x[i]; //just additional hint
|
||||||
|
x1 = x[i + 1];
|
||||||
|
x2 = x[i + 2];
|
||||||
|
z[i] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
|
||||||
|
z[i+1] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
|
||||||
|
z[i+2] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, input->lengthOf(), 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||||
|
|
||||||
|
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||||
|
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||||
|
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
const T* xTad = x + packX.platformOffsets()[i];
|
||||||
|
T* zTad = z + packZ.platformOffsets()[i];
|
||||||
|
//simple M*v //tr.T*v
|
||||||
|
T x0, x1, x2;
|
||||||
|
x0 = xTad[0];
|
||||||
|
x1 = xTad[xDimCstride];
|
||||||
|
x2 = xTad[2 * xDimCstride];
|
||||||
|
zTad[0] = x0 * tr[0][0] + x1 * tr[1][0] + x2 * tr[2][0];
|
||||||
|
zTad[zDimCstride] = x0 * tr[0][1] + x1 * tr[1][1] + x2 * tr[2][1];
|
||||||
|
zTad[2 * zDimCstride] = x0 * tr[0][2] + x1 * tr[1][2] + x2 * tr[2][2];
|
||||||
|
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) {
|
FORCEINLINE static void hsvRgb(const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
@ -124,6 +182,31 @@ FORCEINLINE static void rgbHsv(const NDArray* input, NDArray* output, const int
|
||||||
return tripleTransformer<T>(input, output, dimC, op);
|
return tripleTransformer<T>(input, output, dimC, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static void rgbYiq(const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
T arr[3][3] = {
|
||||||
|
{ (T)0.299, (T)0.59590059, (T)0.2115 },
|
||||||
|
{ (T)0.587, (T)-0.27455667, (T)-0.52273617 },
|
||||||
|
{ (T)0.114, (T)-0.32134392, (T)0.31119955 }
|
||||||
|
};
|
||||||
|
return tripleTransformer<T>(input, output, dimC, arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static void yiqRgb(const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
//TODO: this operation does not use the clamp operation, so there is a possibility being out of range.
|
||||||
|
//Justify that it will not be out of range for images data
|
||||||
|
T arr[3][3] = {
|
||||||
|
{ (T)1, (T)1, (T)1 },
|
||||||
|
{ (T)0.95598634, (T)-0.27201283, (T)-1.10674021 },
|
||||||
|
{ (T)0.6208248, (T)-0.64720424, (T)1.70423049 }
|
||||||
|
};
|
||||||
|
return tripleTransformer<T>(input, output, dimC, arr);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), hsvRgb, (input, output, dimC), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
@ -132,6 +215,15 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), rgbHsv, (input, output, dimC), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Entry point for the YIQ -> RGB conversion: forwards (input, output, dimC)
// to the typed yiqRgb helper selected from input->dataType() among FLOAT_TYPES.
// The context argument is not used in this body.
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (input, output, dimC), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for the RGB -> YIQ conversion: forwards (input, output, dimC)
// to the typed rgbYiq helper selected from input->dataType() among FLOAT_TYPES.
// The context argument is not used in this body.
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (input, output, dimC), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -211,12 +211,97 @@ void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__global__ void tripleTransformerCuda(const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const int dimC, int mode, uint64_t numTads) {
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
auto z = reinterpret_cast<T*>(vz);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong zLen, *sharedMem;
|
||||||
|
__shared__ int rank; // xRank == zRank
|
||||||
|
|
||||||
|
float yiqarr[3][3] = {
|
||||||
|
{ 0.299f, 0.59590059f, 0.2115f },
|
||||||
|
{ 0.587f, -0.27455667f, -0.52273617f },
|
||||||
|
{ 0.114f, -0.32134392f, 0.31119955f }
|
||||||
|
};
|
||||||
|
|
||||||
|
float rgbarr[3][3] = {
|
||||||
|
{ 1.f, 1.f, 1.f },
|
||||||
|
{ 0.95598634f, -0.27201283f, -1.10674021f },
|
||||||
|
{ 0.6208248f, -0.64720424f, 1.70423049f }
|
||||||
|
};
|
||||||
|
|
||||||
|
auto tr = mode == 1? yiqarr : rgbarr;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
rank = shape::rank(zShapeInfo);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
Nd4jLong* coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
if (dimC == (rank - 1) && 'c' == shape::order(xShapeInfo) && 1 == shape::elementWiseStride(xShapeInfo) && 'c' == shape::order(zShapeInfo) && 1 == shape::elementWiseStride(zShapeInfo)) {
|
||||||
|
for (uint64_t f = blockIdx.x * blockDim.x + threadIdx.x; f < zLen / 3; f += gridDim.x * blockDim.x) {
|
||||||
|
auto i = f * 3;
|
||||||
|
|
||||||
|
auto xi0 = x[i];
|
||||||
|
auto xi1 = x[i+1];
|
||||||
|
auto xi2 = x[i+2];
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++)
|
||||||
|
z[i + e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// TAD based case
|
||||||
|
const Nd4jLong xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||||
|
const Nd4jLong zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||||
|
|
||||||
|
for (uint64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numTads; i += blockDim.x * gridDim.x) {
|
||||||
|
const T* xTad = x + xOffsets[i];
|
||||||
|
T* zTad = z + zOffsets[i];
|
||||||
|
|
||||||
|
auto xi0 = xTad[0];
|
||||||
|
auto xi1 = xTad[xDimCstride];
|
||||||
|
auto xi2 = xTad[xDimCstride * 2];
|
||||||
|
|
||||||
|
for (int e = 0; e < 3; e++)
|
||||||
|
zTad[zDimCstride * e] = xi0 * tr[0][e] + xi1 * tr[1][e] + xi2 * tr[2][e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void rgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 1, packZ.numberOfTads());
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static void yiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimC);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimC);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
return tripleTransformerCuda<T><<<256, 256, 8192, *context->getCudaStream()>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformShapeInfo(), packZ.platformOffsets(), dimC, 2, packZ.numberOfTads());
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for the YIQ -> RGB conversion: forwards (context, input, output, dimC)
// to the typed yiqRgb launcher selected from input->dataType() among FLOAT_TYPES.
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), yiqRgb, (context, input, output, dimC), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for the RGB -> YIQ conversion: forwards (context, input, output, dimC)
// to the typed rgbYiq launcher selected from input->dataType() among FLOAT_TYPES.
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), rgbYiq, (context, input, output, dimC), FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
//
|
//
|
||||||
//
|
//
|
||||||
// @author Adel Rauf (rauf@konduit.ai)
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
//
|
//
|
||||||
|
|
||||||
#ifndef LIBND4J_HELPERS_IMAGES_H
|
#ifndef LIBND4J_HELPERS_IMAGES_H
|
||||||
|
@ -33,9 +33,14 @@ namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
|
void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
|
||||||
|
|
||||||
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||||
|
|
||||||
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||||
|
|
||||||
|
// Converts YIQ triples in input to RGB triples in output; dimC is the index of the dimension holding the 3 channels.
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||||
|
|
||||||
|
// Converts RGB triples in input to YIQ triples in output; dimC is the index of the dimension holding the 3 channels.
void transformRgbYiq(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -442,7 +442,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
|
||||||
expected.reshapei({ 3 });
|
expected.reshapei({ 3 });
|
||||||
#if 0
|
#if 0
|
||||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||||
subArrRgbs->printShapeInfo("subArrRgbs");
|
subArrRgbs.printShapeInfo("subArrRgbs");
|
||||||
#endif
|
#endif
|
||||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||||
|
|
||||||
|
@ -636,13 +636,13 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
|
||||||
|
|
||||||
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
auto actual = NDArrayFactory::create<float>('c', { 3 });
|
||||||
//get subarray
|
//get subarray
|
||||||
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||||
subArrHsvs.reshapei({ 3 });
|
subArrHsvs.reshapei({ 3 });
|
||||||
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
|
||||||
expected.reshapei({ 3 });
|
expected.reshapei({ 3 });
|
||||||
#if 0
|
#if 0
|
||||||
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
|
||||||
subArrHsvs->printShapeInfo("subArrHsvs");
|
subArrHsvs.printShapeInfo("subArrHsvs");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Context ctx(1);
|
Context ctx(1);
|
||||||
|
@ -655,3 +655,446 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
|
||||||
ASSERT_TRUE(expected.equalsTo(actual));
|
ASSERT_TRUE(expected.equalsTo(actual));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// rank-3 input, channels in the last dimension, no int argument supplied
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) {
    /*
       Reference values were generated using numpy:

       _rgb_to_yiq_kernel = np.array([[0.299, 0.59590059, 0.2115],
                                      [0.587, -0.27455667, -0.52273617],
                                      [0.114, -0.32134392, 0.31119955]])
       nnrgbs = np.array([random() for x in range(0, 3*4*5)], np.float32).reshape([5, 4, 3])
       out = np.tensordot(nnrgbs, _rgb_to_yiq_kernel, axes=[[len(nnrgbs.shape) - 1], [0]])

       # alternatively, apply the transposed kernel along the last axis
       out_2 = np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x, len(nnrgbs.shape) - 1, nnrgbs)
    */
    auto source = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
        {
            0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
            0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
            0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
            0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
            0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
            0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
            0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
            0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
            0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
            0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
            0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
            0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
        });

    auto reference = NDArrayFactory::create<float>('c', { 5, 4 ,3 },
        {
            0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
            0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
            -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
            0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
            0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
            0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
            0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
            -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
            0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
            0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
            -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
            -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
        });

    auto result = NDArrayFactory::create<float>('c', { 5, 4, 3 });

    Context context(1);
    context.setInputArray(0, &source);
    context.setOutputArray(0, &result);

    nd4j::ops::rgb_to_yiq converter;
    auto status = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(reference.equalsTo(result));
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// rank-3 input with the channels in dimension 1, selected through the int argument
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) {

    auto source = NDArrayFactory::create<float>('c', { 5, 3, 4 },
        {
            0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
            0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
            0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f,
            0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f ,
            0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f,
            0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f ,
            0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f ,
            0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f,
            0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f ,
            0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f,
            0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f ,
            0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f
        });

    auto reference = NDArrayFactory::create<float>('c', { 5, 3, 4 },
        {
            0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
            0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
            -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 0.38032767f,
            0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f,
            0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f,
            0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f,
            -0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f ,
            -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f,
            -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f,
            0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f,
            0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f,
            -0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f
        });

    auto result = NDArrayFactory::create<float>('c', { 5, 3, 4 });

    Context context(1);
    context.setInputArray(0, &source);
    context.setOutputArray(0, &result);
    // channel dimension is 1
    context.setIArguments({ 1 });

    nd4j::ops::rgb_to_yiq converter;
    auto status = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(reference.equalsTo(result));
}
|
||||||
|
|
||||||
|
|
||||||
|
// rank-2 input, channels in the last dimension, no int argument supplied
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) {

    auto source = NDArrayFactory::create<float>('c', { 4, 3 },
        {
            0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
            0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
            0.98633456f, 0.00158441f
        });

    auto reference = NDArrayFactory::create<float>('c', { 4, 3 },
        {
            0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
            0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
            -0.07432612f, -0.44518381f
        });

    auto result = NDArrayFactory::create<float>('c', { 4, 3 });

    Context context(1);
    context.setInputArray(0, &source);
    context.setOutputArray(0, &result);

    nd4j::ops::rgb_to_yiq converter;
    auto status = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(reference.equalsTo(result));
}
|
||||||
|
|
||||||
|
|
||||||
|
// rank-2 input with the channels in dimension 0, selected through the int argument
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) {

    auto source = NDArrayFactory::create<float>('c', { 3, 4 },
        {
            0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f,
            0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f,
            0.48942474f, 0.00158441f
        });

    auto reference = NDArrayFactory::create<float>('c', { 3, 4 },
        {
            0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f,
            0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f,
            -0.04447775f, -0.44518381f
        });

    auto result = NDArrayFactory::create<float>('c', { 3, 4 });

    Context context(1);
    context.setInputArray(0, &source);
    context.setOutputArray(0, &result);
    // channel dimension is 0
    context.setIArguments({ 0 });

    nd4j::ops::rgb_to_yiq converter;
    auto status = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(reference.equalsTo(result));
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// rank-1 input: a single RGB triple
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) {

    auto source = NDArrayFactory::create<float>('c', { 3 },
        { 0.48055f , 0.80757356f, 0.2564435f });
    auto reference = NDArrayFactory::create<float>('c', { 3 },
        { 0.64696468f, -0.01777124f, -0.24070648f, });

    auto result = NDArrayFactory::create<float>('c', { 3 });

    Context context(1);
    context.setInputArray(0, &source);
    context.setOutputArray(0, &result);

    nd4j::ops::rgb_to_yiq converter;
    auto status = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(reference.equalsTo(result));
}
|
||||||
|
|
||||||
|
// rgb_to_yiq on a sub-array: column 0 of a {3, 4} array, reshaped to rank 1.
// The expected result is column 0 of the matching {3, 4} YIQ table.
TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {

    auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
        0.48055f,    0.94277316f, 0.41727918f, 0.3305715f,
        0.80757356f, 0.17006584f, 0.54528666f, 0.98633456f,
        0.2564435f,  0.33366168f, 0.48942474f, 0.00158441f
    });

    auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
         0.64696468f,  0.41975525f,  0.50064416f,  0.67799989f,
        -0.01777124f,  0.40788622f, -0.05832884f, -0.07432612f,
        -0.24070648f,  0.21433232f, -0.04447775f, -0.44518381f
    });

    // Take every row at column index 0, then reshape in place to {3}.
    // NOTE(review): presumably subarray() yields a strided view of the
    // parent buffer rather than a copy -- confirm.
    NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
    subArrRgbs.reshapei({ 3 });

    NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
    expected.reshapei({ 3 });

#if 0
    // shape info layout: [RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
    subArrRgbs.printShapeInfo("subArrRgbs");
#endif

    // Freshly allocated rank-1 output buffer.
    auto actual = NDArrayFactory::create<float>('c', { 3 });

    Context context(1);
    context.setInputArray(0, &subArrRgbs);
    context.setOutputArray(0, &actual);

    nd4j::ops::rgb_to_yiq converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// yiq_to_rgb on a rank-3 {5, 4, 3} array; no integer argument is set.
// The literals are laid out three per row to mirror the trailing size-3 axis.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) {

    auto yiqInput = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
        0.775258899f,  -0.288912386f,  -0.132725924f,
        0.0664454922f, -0.212469354f,   0.455438733f,
        0.418221354f,   0.349350512f,   0.145902053f,
        0.947576523f,  -0.471601307f,   0.263960421f,
        0.700227439f,   0.32434237f,   -0.278446227f,
        0.130805135f,  -0.438441873f,   0.187127829f,
        0.0276055578f, -0.179727226f,   0.305075705f,
        0.716282248f,   0.278215706f,  -0.44586885f,
        0.76971364f,    0.131288841f,  -0.141177326f,
        0.900081575f,  -0.0788725987f,  0.14756602f,
        0.387832165f,   0.229834676f,   0.47921446f,
        0.632930398f,   0.0443540029f, -0.268817365f,
        0.0977194682f, -0.141669706f,  -0.140715122f,
        0.946808815f,  -0.52525419f,   -0.106209636f,
        0.659476519f,   0.391066104f,   0.426448852f,
        0.496989518f,  -0.283434421f,  -0.177366048f,
        0.715208411f,  -0.496444523f,   0.189553142f,
        0.616444945f,   0.345852494f,   0.447739422f,
        0.224696323f,   0.451372236f,   0.298027098f,
        0.446561724f,  -0.187599331f,  -0.448159873f
    });

    auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
         0.416663059f,   0.939747555f,  0.868814286f,
         0.146075352f,  -0.170521997f,  1.07776645f,
         0.842775284f,   0.228765106f,  0.280231822f,
         0.660605291f,   0.905021825f,  1.91936605f,
         0.837427991f,   0.792213732f, -0.133271854f,
        -0.17216571f,    0.128957025f,  0.934955336f,
         0.0451873479f, -0.120952621f,  0.746436225f,
         0.705446224f,   0.929172217f, -0.351493549f,
         0.807577594f,   0.825371955f,  0.383812296f,
         0.916293093f,   0.82603058f,   1.23885956f,
         0.905059196f,   0.015164554f,  0.950156781f,
         0.508443732f,   0.794845279f,  0.12571529f,
        -0.125074273f,   0.227326869f,  0.0147000261f,
         0.378735409f,   1.15842402f,   1.34712305f,
         1.2980804f,     0.277102016f,  0.953435072f,
         0.115916842f,   0.688879376f,  0.508405162f,
         0.35829352f,    0.727568094f,  1.58768577f,
         1.22504294f,    0.232589777f,  0.996727258f,
         0.841224629f,  -0.0909671176f, 0.233051388f,
        -0.0110094378f,  0.787642119f, -0.109582274f
    });

    // Pre-allocated output buffer with the same shape as the input.
    auto actual = NDArrayFactory::create<float>('c', { 5, 4, 3 });

    Context context(1);
    context.setInputArray(0, &yiqInput);
    context.setOutputArray(0, &actual);

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// yiq_to_rgb on a rank-3 {5, 3, 4} array with integer argument 1.
// The literals are laid out four per row to mirror the trailing size-4 axis.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) {

    auto yiqInput = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
         0.775258899f,  0.0664454922f,  0.418221354f,   0.947576523f,
        -0.288912386f, -0.212469354f,   0.349350512f,  -0.471601307f,
        -0.132725924f,  0.455438733f,   0.145902053f,   0.263960421f,
         0.700227439f,  0.130805135f,   0.0276055578f,  0.716282248f,
         0.32434237f,  -0.438441873f,  -0.179727226f,   0.278215706f,
        -0.278446227f,  0.187127829f,   0.305075705f,  -0.44586885f,
         0.76971364f,   0.900081575f,   0.387832165f,   0.632930398f,
         0.131288841f, -0.0788725987f,  0.229834676f,   0.0443540029f,
        -0.141177326f,  0.14756602f,    0.47921446f,   -0.268817365f,
         0.0977194682f, 0.946808815f,   0.659476519f,   0.496989518f,
        -0.141669706f, -0.52525419f,    0.391066104f,  -0.283434421f,
        -0.140715122f, -0.106209636f,   0.426448852f,  -0.177366048f,
         0.715208411f,  0.616444945f,   0.224696323f,   0.446561724f,
        -0.496444523f,  0.345852494f,   0.451372236f,  -0.187599331f,
         0.189553142f,  0.447739422f,   0.298027098f,  -0.448159873f
    });

    auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
         0.416663059f,   0.146075352f,  0.842775284f,   0.660605291f,
         0.939747555f,  -0.170521997f,  0.228765106f,   0.905021825f,
         0.868814286f,   1.07776645f,   0.280231822f,   1.91936605f,
         0.837427991f,  -0.17216571f,   0.0451873479f,  0.705446224f,
         0.792213732f,   0.128957025f, -0.120952621f,   0.929172217f,
        -0.133271854f,   0.934955336f,  0.746436225f,  -0.351493549f,
         0.807577594f,   0.916293093f,  0.905059196f,   0.508443732f,
         0.825371955f,   0.82603058f,   0.015164554f,   0.794845279f,
         0.383812296f,   1.23885956f,   0.950156781f,   0.12571529f,
        -0.125074273f,   0.378735409f,  1.2980804f,     0.115916842f,
         0.227326869f,   1.15842402f,   0.277102016f,   0.688879376f,
         0.0147000261f,  1.34712305f,   0.953435072f,   0.508405162f,
         0.35829352f,    1.22504294f,   0.841224629f,  -0.0110094378f,
         0.727568094f,   0.232589777f, -0.0909671176f,  0.787642119f,
         1.58768577f,    0.996727258f,  0.233051388f,  -0.109582274f
    });

    // Pre-allocated output buffer with the same shape as the input.
    auto actual = NDArrayFactory::create<float>('c', { 5, 3, 4 });

    Context context(1);
    context.setInputArray(0, &yiqInput);
    context.setOutputArray(0, &actual);
    // Integer argument 1: presumably the index of the channel axis
    // (the size-3 dimension) -- TODO confirm against the op.
    context.setIArguments({ 1 });

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
|
||||||
|
// yiq_to_rgb on a rank-2 {4, 3} array; no integer argument is set.
// The literals are laid out three per row to mirror the {4, 3} shape.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) {

    auto yiqInput = NDArrayFactory::create<float>('c', { 4, 3 }, {
        0.775258899f,  -0.288912386f, -0.132725924f,
        0.0664454922f, -0.212469354f,  0.455438733f,
        0.418221354f,   0.349350512f,  0.145902053f,
        0.947576523f,  -0.471601307f,  0.263960421f
    });

    auto expected = NDArrayFactory::create<float>('c', { 4, 3 }, {
        0.416663059f,  0.939747555f, 0.868814286f,
        0.146075352f, -0.170521997f, 1.07776645f,
        0.842775284f,  0.228765106f, 0.280231822f,
        0.660605291f,  0.905021825f, 1.91936605f
    });

    // Pre-allocated output buffer with the same shape as the input.
    auto actual = NDArrayFactory::create<float>('c', { 4, 3 });

    Context context(1);
    context.setInputArray(0, &yiqInput);
    context.setOutputArray(0, &actual);

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
|
||||||
|
// yiq_to_rgb on a rank-2 {3, 4} array with integer argument 0.
// The literals are laid out four per row to mirror the {3, 4} shape.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) {

    auto yiqInput = NDArrayFactory::create<float>('c', { 3, 4 }, {
         0.775258899f,  0.0664454922f, 0.418221354f,  0.947576523f,
        -0.288912386f, -0.212469354f,  0.349350512f, -0.471601307f,
        -0.132725924f,  0.455438733f,  0.145902053f,  0.263960421f
    });

    auto expected = NDArrayFactory::create<float>('c', { 3, 4 }, {
        0.416663059f,  0.146075352f, 0.842775284f, 0.660605291f,
        0.939747555f, -0.170521997f, 0.228765106f, 0.905021825f,
        0.868814286f,  1.07776645f,  0.280231822f, 1.91936605f
    });

    // Pre-allocated output buffer with the same shape as the input.
    auto actual = NDArrayFactory::create<float>('c', { 3, 4 });

    Context context(1);
    context.setInputArray(0, &yiqInput);
    context.setOutputArray(0, &actual);
    // Integer argument 0: presumably the index of the channel axis
    // (the size-3 dimension) -- TODO confirm against the op.
    context.setIArguments({ 0 });

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// yiq_to_rgb on a rank-1 array of three values; no integer argument is set.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) {

    auto yiqInput = NDArrayFactory::create<float>('c', { 3 }, {
        0.775258899f, -0.288912386f, -0.132725924f
    });

    auto expected = NDArrayFactory::create<float>('c', { 3 }, {
        0.416663059f, 0.939747555f, 0.868814286f
    });

    // Pre-allocated output buffer with the same shape as the input.
    auto actual = NDArrayFactory::create<float>('c', { 3 });

    Context context(1);
    context.setInputArray(0, &yiqInput);
    context.setOutputArray(0, &actual);

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

#if 0
    // Disabled debug dump of both buffers.
    actual.printBuffer("actual");
    expected.printBuffer("expected");
#endif

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
||||||
|
// yiq_to_rgb on a sub-array: column 0 of a {3, 4} array, reshaped to rank 1.
// The expected result is column 0 of the matching {3, 4} RGB table.
TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {

    auto yiqs = NDArrayFactory::create<float>('c', { 3, 4 }, {
         0.775258899f,  0.0664454922f, 0.418221354f,  0.947576523f,
        -0.288912386f, -0.212469354f,  0.349350512f, -0.471601307f,
        -0.132725924f,  0.455438733f,  0.145902053f,  0.263960421f
    });

    auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
        0.416663059f,  0.146075352f, 0.842775284f, 0.660605291f,
        0.939747555f, -0.170521997f, 0.228765106f, 0.905021825f,
        0.868814286f,  1.07776645f,  0.280231822f, 1.91936605f
    });

    // Take every row at column index 0, then reshape in place to {3}.
    // NOTE(review): presumably subarray() yields a strided view of the
    // parent buffer rather than a copy -- confirm.
    NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
    subArrYiqs.reshapei({ 3 });

    NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
    expected.reshapei({ 3 });

#if 0
    // shape info layout: [RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
    subArrYiqs.printShapeInfo("subArrYiqs");
#endif

    // Freshly allocated rank-1 output buffer.
    auto actual = NDArrayFactory::create<float>('c', { 3 });

    Context context(1);
    context.setInputArray(0, &subArrYiqs);
    context.setOutputArray(0, &actual);

    nd4j::ops::yiq_to_rgb converter;
    auto result = converter.execute(&context);

    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(expected.equalsTo(actual));
}
|
||||||
|
|
Loading…
Reference in New Issue