Fix for hsv and rgb ranges (#136)

Signed-off-by: Abdelrauf <rauf@konduit.ai>
master
Abdelrauf 2019-12-20 09:48:30 +04:00 committed by raver119
parent 43e118de1e
commit 3c9a2a5cd9
5 changed files with 238 additions and 271 deletions

View File

@ -41,33 +41,33 @@ FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T&
const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b)); const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b));
const T c = max - min; const T c = max - min;
const T _p6 = (T)1 / (T)6;
// calculate h // calculate h
if(c == 0) { if(c == 0) {
h = 0; h = 0;
} }
else if(max == r) { else if(max == r) {
h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360); h = _p6 * ((g - b) / c) + (g >= b ? (T)0 : (T)1);
} }
else if(max == g) { else if(max == g) {
h = 60.f * ((b - r) / c) + 120; h = _p6 * ((b - r) / c + (T)2);
} }
else { // max == b else { // max == b
h = 60.f * ((r - g) / c) + 240; h = _p6 * ((r - g) / c + (T)4);
} }
// calculate s // calculate s
s = max == (T)0 ? (T)0 : c / max; s = max == (T)0 ? (T)0 : c / max;
// calculate v // calculate v
v = max / 255.f; v = max;// / 255.f;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) { FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) {
const float sector = h / 60.f; const float sector = h * 6.f;
const T c = v * s; const T c = v * s;
if(0.f <= sector && sector < 1.f) { if(0.f <= sector && sector < 1.f) {
@ -101,11 +101,12 @@ FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T&
b = v - c * (sector - 5); b = v - c * (sector - 5);
} }
r *= 255; // r *= 255;
g *= 255; // g *= 255;
b *= 255; // b *= 255;
} }
/*//////////////////////////////////////////////////////////////////////////////// /*////////////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {

View File

@ -45,11 +45,11 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
rgbToHsv<T>(x[i], x[i + 1], x[i + 2], h, s, v); rgbToHsv<T>(x[i], x[i + 1], x[i + 2], h, s, v);
h += delta * 360; h += delta ;
if (h > 360) if (h > (T)1)
h -= 360; h -= (T)1;
else if (h < 0) else if (h < 0)
h += 360; h += (T)1;
hsvToRgb<T>(h, s, v, z[i], z[i + 1], z[i + 2]); hsvToRgb<T>(h, s, v, z[i], z[i + 1], z[i + 2]);
} }
@ -76,11 +76,11 @@ static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDAr
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
h += delta * 360; h += delta ;
if (h > 360) if (h > (T)1)
h -= 360; h -= (T)1;
else if (h < 0) else if (h < 0)
h += 360; h += (T)1;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);

View File

@ -58,11 +58,11 @@ static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, co
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
h += delta * 360; h += delta ;
if(h > 360) if(h > 1)
h -= 360; h -= 1;
else if(h < 0) else if(h < 0)
h += 360; h += 1;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
} }

View File

@ -24,7 +24,7 @@
#include <NDArray.h> #include <NDArray.h>
#include <ops/ops.h> #include <ops/ops.h>
#include <GradCheck.h> #include <GradCheck.h>
#include <memory>
using namespace nd4j; using namespace nd4j;
@ -432,7 +432,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) {
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
auto results = op.execute({&input, &factor}, {}, {2}); std::unique_ptr<nd4j::ResultSet> results (op.execute({&input, &factor}, {}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -442,17 +442,18 @@ TEST_F(DeclarableOpsTests13, adjustHue_1) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
delete results;
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_2) { TEST_F(DeclarableOpsTests13, adjustHue_2) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); NDArray input('c', { 2,2,3 }, { 0.f,100.f / 255.f,56.f / 255.f, 17.f / 255.f,220.f / 255.f,5.f / 255.f, 150.f / 255.f,97.f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,13.f / 255.f }, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {4,100,0, 146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32); NDArray exp('c', { 2,2,3 }, { 4.f / 255.f,100.f / 255.f,0.f, 146.f / 255.f,220.f / 255.f,5.f / 255.f, 97.f / 255.f,123.8f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,164.8f / 255.f }, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.9}, {2}); std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.9}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -461,7 +462,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_2) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
delete results;
} }
@ -472,7 +473,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) {
NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {-0.9}, {2}); std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {-0.9}, {2}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -481,7 +482,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_3) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
delete results;
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -491,7 +492,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) {
NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {1}); std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {1}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -500,7 +501,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_4) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
delete results;
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -510,7 +511,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32); NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op; nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {0}); std::unique_ptr<nd4j::ResultSet> results(op.execute({&input}, {0.5}, {0}));
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -519,7 +520,7 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
ASSERT_TRUE(exp.isSameShape(result)); ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
delete results;
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////

View File

@ -239,42 +239,45 @@ TEST_F(DeclarableOpsTests16, test_reverse_1) {
} }
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) {
/* /*
test case generated by python colorsys and scaled to suit our needs test case generated by python colorsys and scaled to suit our needs
from colorsys import * from colorsys import *
from random import * from random import *
import numpy as np import numpy as np
rgbs = np.array([randint(0,255) for x in range(0,3*4*5)]).reshape([5,4,3]) rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3])
hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0]/255,x[1]/255,x[2]/255))*np.array([360,1,1]),2,rgbs) hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs)
rgbs.ravel() rgbs.ravel()
hsvs.ravel() hsvs.ravel()
*/ */
auto rgbs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, auto rgbs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
{ 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f,
213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f,
0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f,
0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f,
0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f,
0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f,
0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f,
0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f
}); });
auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
{ 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f,
6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, 0.823592246f, 0.490824632f, 0.883509099f, 0.385092884f, 0.75257351f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f
}); });
@ -302,31 +305,33 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) {
swapped_rgbs=rgbs.swapaxes(1,2).ravel() swapped_rgbs=rgbs.swapaxes(1,2).ravel()
swapped_hsvs=hsvs.swapaxes(1,2).ravel() swapped_hsvs=hsvs.swapaxes(1,2).ravel()
*/ */
auto rgbs = NDArrayFactory::create<float>('c', { 5,3,4 }, auto rgbs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
{ 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f, 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f, 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f,
168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f, 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f,
221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f, 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f,
209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f,
0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f,
0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f,
0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f,
0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f,
0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f,
0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f
}); });
auto expected = NDArrayFactory::create<float>('c', { 5,3,4 }, auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
{ 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f, 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f, 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f,
8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f, 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f,
2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f, 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f,
7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f, 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f,
9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f, 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f,
3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f, 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f,
5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f, 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f,
4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f, 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f,
2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f, 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f,
7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f, 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f
9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f,
1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f,
9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f,
8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f
}); });
@ -345,22 +350,19 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) {
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) {
/*
2D
*/
auto rgbs = NDArrayFactory::create<float>('c', { 8,3 },
{ 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f });
auto expected = NDArrayFactory::create<float>('c', { 8,3 },
{ 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
77.6f, 0.49019608f, 0.6f, 260.74468085f,
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f });
auto rgbs = NDArrayFactory::create<float>('c', { 4, 3 }, {
0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f,
0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f,
0.54742825f, 0.684074104f
});
auto expected = NDArrayFactory::create<float>('c', { 4, 3 }, {
0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f,
0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f,
0.199753001f, 0.684074104f
});
auto actual = NDArrayFactory::create<float>('c', { 8,3 }); auto actual = NDArrayFactory::create<float>('c', { 4, 3 });
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, &rgbs); ctx.setInputArray(0, &rgbs);
@ -368,12 +370,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) {
nd4j::ops::rgb_to_hsv op; nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
@ -381,22 +378,18 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) {
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) {
/* auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
2D 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
*/ 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
auto rgbs = NDArrayFactory::create<float>('c', { 3,8 }, 0.928968489f, 0.684074104f
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, });
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); auto expected = NDArrayFactory::create<float>('c', { 3, 4 }, {
auto expected = NDArrayFactory::create<float>('c', { 3, 8 }, 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
{ 263.25842697f, 279.86842105f, 71.30044843f, 180.f, 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
77.6f, 260.74468085f, 296.12903226f, 289.82142857f, 0.928968489f, 0.684074104f
0.74476987f, 0.9047619f, 1.f, 0.74871795f, });
0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f,
0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f,
0.6f, 0.81960784f, 0.41960784f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
auto actual = NDArrayFactory::create<float>('c', { 3, 8 });
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, &rgbs); ctx.setInputArray(0, &rgbs);
@ -404,26 +397,19 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) {
ctx.setIArguments({ 0 }); ctx.setIArguments({ 0 });
nd4j::ops::rgb_to_hsv op; nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
} }
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) {
/* auto rgbs = NDArrayFactory::create<float>('c', { 3 }, {
0.545678377f, 0.725874603f, 0.413571358f
*/ });
auto rgbs = NDArrayFactory::create<float>('c', { 3 }, auto expected = NDArrayFactory::create<float>('c', { 3 }, {
{ 213.f, 220.f, 164.f }); 0.262831867f, 0.430244058f, 0.725874603f
auto expected = NDArrayFactory::create<float>('c', { 3 }, });
{ 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f });
auto actual = NDArrayFactory::create<float>('c', { 3 }); auto actual = NDArrayFactory::create<float>('c', { 3 });
@ -433,12 +419,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) {
nd4j::ops::rgb_to_hsv op; nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
#if 0
//visual check
rgbs.printBuffer("rgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected.equalsTo(actual));
@ -446,19 +427,22 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) {
TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
/* auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
*/ 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
auto rgbs = NDArrayFactory::create<float>('c', { 3,8 }, 0.928968489f, 0.684074104f
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, });
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); auto hsvs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
auto expected = NDArrayFactory::create<float>('c', { 3 }, 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
{ 263.25842697f, 0.74476987f, 0.9372549f }); 0.928968489f, 0.684074104f
});
//get subarray //get subarray
std::unique_ptr<NDArray> subArrRgbs(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) })); std::unique_ptr<NDArray> subArrRgbs(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }));
std::unique_ptr<NDArray> expected(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }));
subArrRgbs->reshapei({ 3 }); subArrRgbs->reshapei({ 3 });
expected->reshapei({ 3 });
#if 0 #if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrRgbs->printShapeInfo("subArrRgbs"); subArrRgbs->printShapeInfo("subArrRgbs");
@ -468,48 +452,43 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, subArrRgbs.get()); ctx.setInputArray(0, subArrRgbs.get());
ctx.setOutputArray(0, &actual); ctx.setOutputArray(0, &actual);
ctx.setIArguments({ 0 });
nd4j::ops::rgb_to_hsv op; nd4j::ops::rgb_to_hsv op;
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
#if 0
//visual check
subArrRgbs->printBuffer("subArrRgbs ");
actual.printBuffer("HSV ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual)); ASSERT_TRUE(expected->equalsTo(actual));
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) {
/*
using the same numbers of rgb_to_hsv_1 test auto hsvs = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
*/ 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f,
auto expected = NDArrayFactory::create<float>('c', { 5,4,3 }, 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f,
{ 213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f } 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f,
); 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f,
auto hsvs = NDArrayFactory::create<float>('c', { 5,4,3 }, 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f,
{ 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f,
6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, });
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, auto expected = NDArrayFactory::create<float>('c', { 5, 4, 3 }, {
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, 0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f,
0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f
}); });
@ -528,36 +507,34 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) {
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) {
/* auto hsvs = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
using the same numbers of hsv_to_rgb_2 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f,
*/ 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f,
auto expected = NDArrayFactory::create<float>('c', { 5,3,4 }, 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f,
{ 213.f, 121.f, 18.f, 235.f, 220.f, 180.f, 245.f, 76.f, 164.f, 180.f, 75.f, 74.f, 168.f, 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f,
191.f, 207.f, 77.f, 50.f, 132.f, 37.f, 250.f, 233.f, 100.f, 245.f, 182.f, 111.f, 193.f, 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f,
168.f, 48.f, 52.f, 147.f, 103.f, 191.f, 59.f, 137.f, 121.f, 187.f, 53.f, 156.f, 244.f, 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f,
221.f, 82.f, 37.f, 90.f, 98.f, 239.f, 118.f, 7.f, 243.f, 12.f, 2.f, 79.f, 70.f, 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f,
209.f, 115.f, 247.f, 152.f, 192.f, 205.f, 32.f, 180.f } 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f,
); 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f,
auto hsvs = NDArrayFactory::create<float>('c', { 5,3,4 }, 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f,
{ 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f,
6.75000000e+01f, 1.80000000e+02f, 1.35066079e+02f, 7.45341615e-01f, 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f
2.54545455e-01f, 3.27777778e-01f, 9.26530612e-01f, 6.85106383e-01f, });
8.62745098e-01f, 7.05882353e-01f, 9.60784314e-01f, 9.21568627e-01f, auto expected = NDArrayFactory::create<float>('c', { 5, 3, 4 }, {
2.78688525e+02f, 2.10989011e+01f, 2.89038462e+02f, 1.56416185e+02f, 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f,
7.85407725e-01f, 4.76439791e-01f, 8.48979592e-01f, 6.92000000e-01f, 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f,
9.13725490e-01f, 7.49019608e-01f, 9.60784314e-01f, 9.80392157e-01f, 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f,
3.52881356e+02f, 1.07142857e+01f, 3.43384615e+02f, 1.78321678e+02f, 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f,
5.31531532e-01f, 2.90155440e-01f, 3.86904762e-01f, 7.48691099e-01f, 0.432045438f, 0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f,
4.35294118e-01f, 7.56862745e-01f, 6.58823529e-01f, 7.49019608e-01f, 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f,
2.30645161e+02f, 3.19159664e+02f, 2.10126582e+01f, 2.90896552e+02f, 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f,
7.78242678e-01f, 7.62820513e-01f, 9.71311475e-01f, 5.96707819e-01f, 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f,
9.37254902e-01f, 6.11764706e-01f, 9.56862745e-01f, 9.52941176e-01f, 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f,
1.74822335e+02f, 2.06600985e+02f, 1.06883721e+02f, 1.95272727e+02f, 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f,
9.42583732e-01f, 9.90243902e-01f, 8.70445344e-01f, 6.11111111e-01f, 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f,
8.19607843e-01f, 8.03921569e-01f, 9.68627451e-01f, 7.05882353e-01f 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f
}); });
auto actual = NDArrayFactory::create<float>('c', { 5,3,4 }); auto actual = NDArrayFactory::create<float>('c', { 5,3,4 });
Context ctx(1); Context ctx(1);
@ -573,22 +550,17 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) {
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) {
/* auto hsvs = NDArrayFactory::create<float>('c', { 4, 3 }, {
2D 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f,
*/ 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f,
auto expected = NDArrayFactory::create<float>('c', { 8,3 }, 0.332347751f, 0.111181192f
{ 130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, });
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f }); auto expected = NDArrayFactory::create<float>('c', { 4, 3 }, {
auto hsvs = NDArrayFactory::create<float>('c', { 8,3 }, 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f,
{ 263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.111181192f, 0.074230373f
0.8745098f, 180.f, 0.74871795f, 0.76470588f, });
77.6f, 0.49019608f, 0.6f, 260.74468085f, auto actual = NDArrayFactory::create<float>('c', { 4,3 });
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 8,3 });
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, &hsvs); ctx.setInputArray(0, &hsvs);
@ -602,23 +574,19 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) {
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) {
/* auto hsvs = NDArrayFactory::create<float>('c', { 3, 4 }, {
2D 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f,
*/ 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f,
auto expected = NDArrayFactory::create<float>('c', { 3,8 }, 0.773604929f, 0.111181192f
{ 130.f, 117.f, 181.f, 49.f, 131.f, 86.f, 101.f, 191.f, 61.f, 16.f, 223.f, 195.f, 153.f, });
21.f, 14.f, 98.f, 239.f, 168.f, 0.f, 195.f, 78.f, 209.f, 107.f, 210.f }); auto expected = NDArrayFactory::create<float>('c', { 3, 4 }, {
auto hsvs = NDArrayFactory::create<float>('c', { 3,8 }, 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f,
{ 263.25842697f, 279.86842105f, 71.30044843f, 180.f, 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f,
77.6f, 260.74468085f, 296.12903226f, 289.82142857f, 0.773604929f, 0.074230373f
0.74476987f, 0.9047619f, 1.f, 0.74871795f, });
0.49019608f, 0.89952153f, 0.86915888f, 0.53333333f, auto actual = NDArrayFactory::create<float>('c', { 3, 4 });
0.9372549f, 0.65882353f, 0.8745098f, 0.76470588f,
0.6f, 0.81960784f, 0.41960784f, 0.82352941f });
auto actual = NDArrayFactory::create<float>('c', { 3,8 });
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, &hsvs); ctx.setInputArray(0, &hsvs);
@ -633,14 +601,13 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) {
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) {
/*
*/
auto expected = NDArrayFactory::create<float>('c', { 3 },
{ 213.f, 220.f, 164.f });
auto hsvs = NDArrayFactory::create<float>('c', { 3 },
{ 6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f });
auto hsvs = NDArrayFactory::create<float>('c', { 3 }, {
0.705504596f, 0.793608069f, 0.65870738f
});
auto expected = NDArrayFactory::create<float>('c', { 3 }, {
0.257768334f, 0.135951888f, 0.65870738f
});
auto actual = NDArrayFactory::create<float>('c', { 3 }); auto actual = NDArrayFactory::create<float>('c', { 3 });
@ -656,40 +623,38 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) {
} }
TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
auto expected = NDArrayFactory::create<double>('c', { 3 }, auto hsvs = NDArrayFactory::create<float>('c', { 3, 4 }, {
{ 130.0, 61.0, 239.0 }); 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f,
auto hsvs = NDArrayFactory::create<double>('c', { 3,8 }, 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f,
{ 263.25842697, 279.86842105, 71.30044843, 180, 0.773604929f, 0.111181192f
77.6, 260.74468085, 296.12903226, 289.82142857, });
0.74476987, 0.9047619, 1., 0.74871795, auto rgbs = NDArrayFactory::create<float>('c', { 3, 4 }, {
0.49019608, 0.89952153, 0.86915888, 0.53333333, 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f,
0.9372549, 0.65882353, 0.8745098, 0.76470588, 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f,
0.6, 0.81960784, 0.41960784, 0.82352941 0.773604929f, 0.074230373f
}); });
auto actual = NDArrayFactory::create<float>('c', { 3 });
//get subarray //get subarray
std::unique_ptr<NDArray> subArrHsvs(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) })); std::unique_ptr<NDArray> subArrHsvs(hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }));
subArrHsvs->reshapei({ 3 }); subArrHsvs->reshapei({ 3 });
std::unique_ptr<NDArray> expected(rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }));
expected->reshapei({ 3 });
#if 0 #if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrHsvs->printShapeInfo("subArrHsvs"); subArrHsvs->printShapeInfo("subArrHsvs");
#endif #endif
auto actual = NDArrayFactory::create<double>('c', { 3 });
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, subArrHsvs.get()); ctx.setInputArray(0, subArrHsvs.get());
ctx.setOutputArray(0, &actual); ctx.setOutputArray(0, &actual);
nd4j::ops::hsv_to_rgb op; nd4j::ops::hsv_to_rgb op;
auto status = op.execute(&ctx); auto status = op.execute(&ctx);
#if 0
//visual check
subArrHsvs->printBuffer("subArrHsvs ");
actual.printBuffer("rgb ");
expected.printBuffer("exp");
#endif
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
} ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected->equalsTo(actual));
}