[WIP] Oleh rgb yuv (#147)
* libnd4j: RgbToYuv and YuvToRgb, both implementations for both cpu and cuda. Need adding tests and review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: RgbToYuv and YuvToRgb, replace coords method on Tad in both cpu and cuda, add tests, fixed bugs Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: RgbToYuv and YuvToRgb minor corrections Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: RgbToYuv and YuvToRgb corrections to use operations in-placemaster
parent
d1e5e79c10
commit
75123b0a4c
|
@ -0,0 +1,57 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
|
||||
|
||||
|
||||
#include <ops/declarable/headers/images.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(rgb_to_yuv, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int rank = input->rankOf();
|
||||
const int argSize = block.getIArguments()->size();
|
||||
const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
|
||||
REQUIRE_TRUE(rank >= 1, 0, "RGBtoYUV: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||
if (argSize > 0) {
|
||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||
}
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "RGBtoYUV: operation expects 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
helpers::transformRgbYuv(block.launchContext(), *input, *output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(rgb_to_yuv) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/headers/images.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
CONFIGURABLE_OP_IMPL(yuv_to_rgb, 1, 1, true, 0, 0) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int rank = input->rankOf();
|
||||
const int argSize = block.getIArguments()->size();
|
||||
const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
|
||||
REQUIRE_TRUE(rank >= 1, 0, "YUVtoRGB: Fails to meet the rank requirement: %i >= 1 ", rank);
|
||||
if (argSize > 0) {
|
||||
REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "Index of the Channel dimension out of range: %i not in [%i,%i) ", INT_ARG(0), -rank, rank);
|
||||
}
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "YUVtoRGB: operation expects 3 channels (Y, U, V), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
helpers::transformYuvRgb(block.launchContext(), *input, *output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(yuv_to_rgb) {
|
||||
getOpDescriptor()->setAllowedInputTypes({ ALL_FLOATS })
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
|
@ -42,7 +42,7 @@ namespace ops {
|
|||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_rgb_to_hsv)
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, false, 0, 0);
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_hsv, 1, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -53,7 +53,7 @@ namespace ops {
|
|||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_hsv_to_rgb)
|
||||
DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, false, 0, 0);
|
||||
DECLARE_CONFIGURABLE_OP(hsv_to_rgb, 1, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -65,6 +65,27 @@ namespace ops {
|
|||
DECLARE_CUSTOM_OP(rgb_to_grs, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Rgb To Yuv
|
||||
* Input arrays:
|
||||
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
* Int arguments:
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_rgb_to_yuv)
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_yuv, 1, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Yuv To Rgb
|
||||
* Input arrays:
|
||||
* 0 - input array with rank >= 1, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
* Int arguments:
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_rgb_to_yuv)
|
||||
DECLARE_CONFIGURABLE_OP(yuv_to_rgb, 1, 1, true, 0, 0);
|
||||
|
||||
/**
|
||||
* Rgb To Yiq
|
||||
* Input arrays:
|
||||
|
@ -73,7 +94,7 @@ namespace ops {
|
|||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_rgb_to_yiq)
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, false, 0, 0);
|
||||
DECLARE_CONFIGURABLE_OP(rgb_to_yiq, 1, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -84,10 +105,11 @@ namespace ops {
|
|||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_yiq_to_rgb)
|
||||
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, false, 0, 0);
|
||||
DECLARE_CONFIGURABLE_OP(yiq_to_rgb, 1, 1, true, 0, 0);
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -106,7 +107,22 @@ FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T&
|
|||
// b *= 255;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD void rgbYuv(const T& r, const T& g, const T& b, T& y, T& u, T& v) {
|
||||
y = static_cast<T>(0.299) * r + static_cast<T>(0.587) *g + static_cast<T>(0.114) * b;
|
||||
u = -static_cast<T>(0.14714119) * r - static_cast<T>(0.2888691) * g + static_cast<T>(0.43601035) * b;
|
||||
v = static_cast<T>(0.61497538) * r - static_cast<T>(0.51496512) * g - static_cast<T>(0.10001026) * b;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD void yuvRgb(const T& y, const T& u, const T& v, T& r, T& g, T& b) {
|
||||
r = y + static_cast<T>(1.13988303) * v;
|
||||
g = y - static_cast<T>(0.394642334) * u - static_cast<T>(0.58062185) * v;
|
||||
b = y + static_cast<T>(2.03206185) * u;
|
||||
}
|
||||
|
||||
/*////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
|
||||
|
|
|
@ -70,6 +70,65 @@ void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray
|
|||
BUILD_SINGLE_SELECTOR(input.dataType(), rgbToGrs_, (input, output, dimC), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
FORCEINLINE static void rgbToFromYuv_(const NDArray& input, NDArray& output, const int dimC, Op op) {
|
||||
|
||||
const T* x = input.bufferAsT<T>();
|
||||
T* z = output.bufferAsT<T>();
|
||||
const int rank = input.rankOf();
|
||||
bool bSimple = (dimC == rank - 1 && 'c' == input.ordering() && 1 == input.ews() &&
|
||||
'c' == output.ordering() && 1 == output.ews());
|
||||
|
||||
if (bSimple) {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
op(x[i], x[i + 1], x[i + 2], z[i], z[i + 1], z[i + 2]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, input.lengthOf(), 3);
|
||||
return;
|
||||
}
|
||||
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimC);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), dimC);
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
const Nd4jLong xDimCstride = input.stridesOf()[dimC];
|
||||
const Nd4jLong zDimCstride = output.stridesOf()[dimC];
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
const T* xTad = x + packX.platformOffsets()[i];
|
||||
T* zTad = z + packZ.platformOffsets()[i];
|
||||
op(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, numOfTads);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void rgbYuv_(const NDArray& input, NDArray& output, const int dimC) {
|
||||
auto op = nd4j::ops::helpers::rgbYuv<T>;
|
||||
return rgbToFromYuv_<T>(input, output, dimC, op);
|
||||
}
|
||||
|
||||
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), rgbYuv_, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static void yuvRgb_(const NDArray& input, NDArray& output, const int dimC) {
|
||||
auto op = nd4j::ops::helpers::yuvRgb<T>;
|
||||
return rgbToFromYuv_<T>(input, output, dimC, op);
|
||||
}
|
||||
|
||||
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), yuvRgb_, (input, output, dimC), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
template <typename T, typename Op>
|
||||
FORCEINLINE static void tripleTransformer(const NDArray* input, NDArray* output, const int dimC, Op op) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -29,6 +30,117 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ void rgbToYuvCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ int rank;
|
||||
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rank = shape::rank(xShapeInfo);
|
||||
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||
const T* xTad = x + xTadOffsets[i];
|
||||
T* zTad = z + zTadOffsets[i];
|
||||
|
||||
rgbYuv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
linkage void rgbToYuvCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
|
||||
|
||||
rgbToYuvCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
|
||||
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC });
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC });
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
PointersManager manager(context, "yuv_to_rgb");
|
||||
|
||||
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), rgbToYuvCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES);
|
||||
NDArray::registerSpecialUse({ &output }, { &input });
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ void yuvToRgbCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ int rank;
|
||||
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rank = shape::rank(xShapeInfo);
|
||||
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||
const T* xTad = x + xTadOffsets[i];
|
||||
T* zTad = z + zTadOffsets[i];
|
||||
|
||||
yuvRgb<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
linkage void yuvToRgbCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const int dimC) {
|
||||
|
||||
yuvToRgbCuda<T> << <blocksPerGrid, threadsPerBlock, 256, * stream >> > (vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, dimC);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC) {
|
||||
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), { dimC });
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), { dimC });
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
PointersManager manager(context, "yuv_to_rgb");
|
||||
|
||||
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), yuvToRgbCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), packX.platformOffsets(), output.specialBuffer(), output.specialShapeInfo(), packZ.platformOffsets(), numOfTads, dimC), FLOAT_TYPES);
|
||||
NDArray::registerSpecialUse({ &output }, { &input });
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// for example xShapeInfo = {2,3,4}, zShapeInfo = {2,1,4}
|
||||
template<typename T>
|
||||
|
@ -83,7 +195,7 @@ void transformRgbGrs(nd4j::LaunchContext* context, const NDArray& input, NDArray
|
|||
|
||||
PointersManager manager(context, "rgbToGrs");
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ namespace helpers {
|
|||
void transformHsvRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
|
||||
void transformRgbHsv(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
void transformYuvRgb(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
|
||||
void transformRgbYuv(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const int dimC);
|
||||
|
||||
void transformYiqRgb(nd4j::LaunchContext* context, const NDArray* input, NDArray* output, const int dimC);
|
||||
|
||||
|
|
|
@ -1064,5 +1064,234 @@ TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) {
|
|||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) {
|
||||
// rank 1
|
||||
NDArray rgbs('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) {
|
||||
|
||||
NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. }, nd4j::DataType::FLOAT32);
|
||||
rgbs.permutei({ 1,0 });
|
||||
|
||||
NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, nd4j::DataType::FLOAT32);
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) {
|
||||
// rank 2
|
||||
NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, { 0 });
|
||||
auto output = result->at(0);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) {
|
||||
// rank 3
|
||||
NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) {
|
||||
// rank 3
|
||||
NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, { 1 });
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) {
|
||||
// rank 3
|
||||
NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||
try {
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
ASSERT_EQ(Status::THROW(), result->status());
|
||||
delete result;
|
||||
}
|
||||
catch (std::exception & e) {
|
||||
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) {
|
||||
// rank 3
|
||||
NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::rgb_to_yuv op;
|
||||
auto result = op.execute({ &rgbs }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) {
|
||||
// rank 1
|
||||
NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) {
|
||||
// rank 1
|
||||
NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('f', { 3 }, { 10, 50, 200 }, nd4j::DataType::FLOAT32);
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) {
|
||||
// rank 2
|
||||
NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, nd4j::DataType::FLOAT32);
|
||||
NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, { 0 });
|
||||
auto output = result->at(0);
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) {
|
||||
// rank 3
|
||||
NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, nd4j::DataType::FLOAT32);
|
||||
NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) {
|
||||
// rank 3
|
||||
NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||
NDArray yuv('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, { 1 });
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) {
|
||||
// rank 3
|
||||
NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, nd4j::DataType::FLOAT32);
|
||||
try {
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, {});
|
||||
ASSERT_EQ(Status::THROW(), result->status());
|
||||
delete result;
|
||||
}
|
||||
catch (std::exception & e) {
|
||||
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) {
|
||||
// rank 3
|
||||
NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, nd4j::DataType::FLOAT32);
|
||||
NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::yuv_to_rgb op;
|
||||
auto result = op.execute({ &yuv }, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete result;
|
||||
}
|
Loading…
Reference in New Issue