Shugeo atomic tests (#97)

* Added atomic tests for atomicAdd, atomicSub and atomicDiv.

* Fixed atomicAdd for 16bit ints.

* Fixed atomicMul for 16 floats.

* Eliminated waste prints.

* Fixed problems with double type on matrix inverse helepers.

* Eliminated commented wrong code.

* Refactored atomicMul for 16bit types.

* few more minor tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed fake_quant_with_min_max_vars_per_channel args processing.
master
shugeo 2019-12-02 20:40:54 +02:00 committed by raver119
parent 25b3cd9b80
commit 1e9ff114aa
5 changed files with 338 additions and 41 deletions

View File

@ -25,13 +25,12 @@
#include <ops/declarable/helpers/fake_quantization.h> #include <ops/declarable/helpers/fake_quantization.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 1, 1, true, 0, 0) { CONFIGURABLE_OP_IMPL(fake_quant_with_min_max_vars_per_channel, 3, 1, true, 0, 0) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto min = INPUT_VARIABLE(1); auto min = INPUT_VARIABLE(1);
auto max = INPUT_VARIABLE(2); auto max = INPUT_VARIABLE(2);
REQUIRE_TRUE(block.width() == 3 || block.getTArguments()->size() == 2, 0, "fake_quant_with_min_max_vars_per_channel: No minimum/maximum values provided by either input arrays or TArgs");
auto depth = x->sizeAt(-1); auto depth = x->sizeAt(-1);
REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0, REQUIRE_TRUE(min->rankOf() == 1 && max->rankOf() == 1 && min->lengthOf() == max->lengthOf(), 0,
"fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length"); "fake_quant_with_min_max_vars_per_channel: Min and Max should be 1D tensors with the same length");
@ -49,13 +48,13 @@ namespace nd4j {
numBits = INT_ARG(0); numBits = INT_ARG(0);
bool narrowed = false; bool narrowed = false;
//INT_ARG(1); //INT_ARG(1);
if (block.getIArguments()->size() == 2) { if (block.getBArguments() && block.getBArguments()->size()) {
numBits = INT_ARG(0); narrowed = B_ARG(0);
narrowed = INT_ARG(1);
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
" for quatization should be in between 2 and 16, but %i "
"was given.", numBits);
} }
REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars_per_channel: Number of bits"
" for quatization should be in between 2 and 16, but %i "
"was given.", numBits);
helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output); helpers::fakeQuantWithMinMaxVarsPerChannel(block.launchContext(), x, min, max, numBits, narrowed, output);
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }

View File

@ -110,12 +110,21 @@ namespace helpers {
template<typename T> template<typename T>
static __global__ void static __global__ void
invertLowKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) { invertLowKernel(void *invertedBuf, Nd4jLong *invertedShape, void *inputBuf, Nd4jLong *inputShape, Nd4jLong n) {
T *inverted = reinterpret_cast<T *>(invertedBuf); T *inverted = reinterpret_cast<T *>(invertedBuf);
T *input = reinterpret_cast<T *>(inputBuf); T *input = reinterpret_cast<T *>(inputBuf);
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T *>(invertedBuf);
input = reinterpret_cast<T *>(inputBuf);
}
__syncthreads();
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) { auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (int i = tid + 2; i < n; i += step) {
for (int j = i - 2; j >= 0; --j) for (int j = i - 2; j >= 0; --j)
for (int k = threadIdx.x; k < i; k += blockDim.x) { for (int k = 0; k < i; k++) {
Nd4jLong posZ[] = {i, j}; Nd4jLong posZ[] = {i, j};
Nd4jLong posY[] = {k, j}; Nd4jLong posY[] = {k, j};
Nd4jLong posX[] = {i, k}; Nd4jLong posX[] = {i, k};
@ -144,10 +153,12 @@ namespace helpers {
input = reinterpret_cast<T *>(inputBuf); input = reinterpret_cast<T *>(inputBuf);
} }
__syncthreads(); __syncthreads();
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (int i = (int)n - blockIdx.x - 2; i >= 0; i -= gridDim.x) { for (int i = (int)n - tid - 2; i >= 0; i -= step) {
for (int j = i + 2; j < (int)n; j++) for (int j = i + 2; j < (int)n; j++)
for (int k = i + threadIdx.x; k < (int)n; k += blockDim.x) { for (int k = i; k < (int)n; k++) {
Nd4jLong posZ[] = {i, j}; Nd4jLong posZ[] = {i, j};
Nd4jLong posY[] = {k, j}; Nd4jLong posY[] = {k, j};
Nd4jLong posX[] = {i, k}; Nd4jLong posX[] = {i, k};
@ -498,8 +509,6 @@ namespace helpers {
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// else // else
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T>(context, &matrix, nullptr, nullptr);
// else // else
// lup_<float>(context, &matrix, nullptr, nullptr); // lup_<float>(context, &matrix, nullptr, nullptr);
@ -627,9 +636,14 @@ namespace helpers {
for (auto i = 0LL; i < packX.numberOfTads(); i++) { for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); fillMatrix<T, T><<<1, n2, 1024, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
matrix.tickWriteDevice(); matrix.tickWriteDevice();
compound.assign(matrix); //compound.assign(matrix);
lup_<T>(context, &compound, nullptr, nullptr); // if (matrix.dataType() == input->dataType())
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); lup_<T>(context, &matrix, nullptr, nullptr);
fillLowerUpperKernel<T><<<n, n, 1024, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
lower.tickWriteDevice();
upper.tickWriteDevice();
// lower.printIndexedBuffer("LOWER");
// upper.printIndexedBuffer("UPPER");
matrix.assign(0); matrix.assign(0);
invertUpperMatrix(context, &upper, &matrix); // U^{-1} invertUpperMatrix(context, &upper, &matrix); // U^{-1}
matrix.tickWriteDevice(); matrix.tickWriteDevice();

View File

@ -1305,15 +1305,65 @@ inline __device__ bfloat16 nd4j_atomicAdd<bfloat16>(bfloat16* address, bfloat16
else return old.B.L; else return old.B.L;
} }
template <typename T>
static inline __device__ T internal_16bit_atomicAdd(T* address, T val) {
size_t shift = ((size_t)address & 2);
int *base_address = (int *)((char*)address - shift);
union I16PAIR {
struct {
T H;
T L;
} B;
int W;
__host__ __device__
I16PAIR() {};
__host__ __device__
~I16PAIR() {};
};
I16PAIR pairNew, pairOld, pairAssumed;
if (reinterpret_cast<int*>(address) == base_address) {
pairOld.B.L = val;
do {
pairNew.B.L = pairOld.B.L;
pairNew.B.H = pairOld.B.H + val;
pairAssumed.W = pairOld.W;
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
} while (pairAssumed.W != pairOld.W);
return (T) pairOld.B.H;
} else {
pairOld.B.H = val;
do {
pairNew.B.H = pairOld.B.H;
pairNew.B.L = pairOld.B.L + val;
pairAssumed.W = pairOld.W;
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
} while (pairAssumed.W != pairOld.W);
return (T) pairOld.B.L;
}
}
template <> template <>
inline __device__ int16_t nd4j_atomicAdd<int16_t>(int16_t* address, int16_t val) { inline __device__ int16_t nd4j_atomicAdd<int16_t>(int16_t* address, int16_t val) {
return nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); return internal_16bit_atomicAdd<int16_t>(address, val);
} }
template <> template <>
inline __device__ uint16_t nd4j_atomicAdd<uint16_t>(uint16_t* address, uint16_t val) { inline __device__ uint16_t nd4j_atomicAdd<uint16_t>(uint16_t* address, uint16_t val) {
return nd4j_atomicAdd((bfloat16*)address, (bfloat16)val); return internal_16bit_atomicAdd<uint16_t>(address, val);
} }
template <> template <>
inline __device__ int8_t nd4j_atomicAdd<int8_t>(int8_t* address, int8_t val) { inline __device__ int8_t nd4j_atomicAdd<int8_t>(int8_t* address, int8_t val) {
int res = *address; int res = *address;
@ -1447,7 +1497,7 @@ inline __device__ unsigned char nd4j_atomicMul<unsigned char>(unsigned char* add
} }
template <typename T> template <typename T>
static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) { static inline __device__ T internal_16bit_atomicMul(T* address, T val) {
size_t shift = ((size_t)address & 2); size_t shift = ((size_t)address & 2);
int *base_address = (int *)((char*)address - shift); int *base_address = (int *)((char*)address - shift);
@ -1467,10 +1517,9 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) {
I16PAIR pairNew, pairOld, pairAssumed; I16PAIR pairNew, pairOld, pairAssumed;
pairOld.W = (int) val;
if (reinterpret_cast<int*>(address) == base_address) { if (reinterpret_cast<int*>(address) == base_address) {
pairOld.B.L = val;
do { do {
pairNew.B.L = pairOld.B.L; pairNew.B.L = pairOld.B.L;
pairNew.B.H = pairOld.B.H * val; pairNew.B.H = pairOld.B.H * val;
pairAssumed.W = pairOld.W; pairAssumed.W = pairOld.W;
@ -1480,8 +1529,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) {
return (T) pairOld.B.H; return (T) pairOld.B.H;
} else { } else {
pairOld.B.H = val;
do { do {
pairNew.B.H = pairOld.B.H; pairNew.B.H = pairOld.B.H;
pairNew.B.L = pairOld.B.L * val; pairNew.B.L = pairOld.B.L * val;
pairAssumed.W = pairOld.W; pairAssumed.W = pairOld.W;
@ -1491,10 +1540,8 @@ static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) {
return (T) pairOld.B.L; return (T) pairOld.B.L;
} }
} }
template <> template <>
inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) { inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) {
return internal_16bit_atomicMul<int16_t>(address, val); return internal_16bit_atomicMul<int16_t>(address, val);
@ -1549,17 +1596,6 @@ inline __device__ uint64_t nd4j_atomicMul<uint64_t>(uint64_t* address, uint64_t
return (uint64_t)old; return (uint64_t)old;
} }
//template <>
//inline __device__ unsigned long long nd4j_atomicMul<unsigned long long>(unsigned long long* address, unsigned long long val) {
// unsigned long long int* res_address = address;
// unsigned long long int old = *res_address, assumed;
// do {
// assumed = old;
// old = atomicCAS(res_address, assumed, val * assumed);
// } while (assumed != old);
// return old;
//}
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
template <> template <>
inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong val) { inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong val) {
@ -1585,22 +1621,21 @@ inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val)
template <> template <>
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) { inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
return nd4j_atomicMul<float>(address, (float) 1.f / val); return nd4j_atomicMul<float>(address, 1.f / val);
} }
template <> template <>
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) { inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
return nd4j_atomicMul<float16>(address, (float16) 1.f / val); return internal_16bit_atomicMul<float16>(address, (float16) 1.f / val);
} }
template <> template <>
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
return nd4j_atomicMul<bfloat16>(address, (bfloat16) 1.f / val); return internal_16bit_atomicMul<bfloat16>(address, (bfloat16) 1 / val);
} }
} }
#endif #endif
} }
} }
#ifdef _OPENMP #ifdef _OPENMP

View File

@ -60,12 +60,93 @@ static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) {
nd4j::cuda_exception::build("multiply failed", err); nd4j::cuda_exception::build("multiply failed", err);
} }
template <typename T>
static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) {
auto buffer = reinterpret_cast<T*>(vbuffer);
auto result = reinterpret_cast<T*>(vresult);
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
auto rem = e % 4;
auto i = (e - rem) / 4;
nd4j::math::atomics::nd4j_atomicAdd<T>(&result[i], buffer[e]);
}
}
template <typename T>
static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) {
sumKernel<T><<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream());
if (err != 0)
nd4j::cuda_exception::build("sum failed", err);
}
template <typename T>
static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) {
auto buffer = reinterpret_cast<T*>(vbuffer);
auto result = reinterpret_cast<T*>(vresult);
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
auto rem = e % 4;
auto i = (e - rem) / 4;
nd4j::math::atomics::nd4j_atomicSub<T>(&result[i], buffer[e]);
}
}
template <typename T>
static void subLauncher(void *vbuffer, uint64_t length, void *vresult) {
subKernel<T><<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream());
if (err != 0)
nd4j::cuda_exception::build("sub failed", err);
}
template <typename T>
static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) {
auto buffer = reinterpret_cast<T*>(vbuffer);
auto result = reinterpret_cast<T*>(vresult);
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
auto rem = e % 4;
auto i = (e - rem) / 4;
nd4j::math::atomics::nd4j_atomicDiv<T>(&result[i], buffer[e]);
}
}
template <typename T>
static void divLauncher(void *vbuffer, uint64_t length, void *vresult) {
divKernel<T><<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream());
if (err != 0)
nd4j::cuda_exception::build("div failed", err);
}
static void multiplyHost(NDArray &input, NDArray &output) { static void multiplyHost(NDArray &input, NDArray &output) {
BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
} }
static void sumHost(NDArray &input, NDArray &output) {
BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
}
static void subHost(NDArray &input, NDArray &output) {
BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES);
}
static void divHost(NDArray &input, NDArray &output) {
BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES);
}
TEST_F(AtomicTests, test_multiply) { TEST_F(AtomicTests, test_multiply) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16}; std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16, nd4j::DataType::HALF};
for (auto t:dtypes) { for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
@ -80,7 +161,81 @@ TEST_F(AtomicTests, test_multiply) {
multiplyHost(input, output); multiplyHost(input, output);
ASSERT_EQ(exp, output); ASSERT_EQ(exp, output);
} }
}
TEST_F(AtomicTests, test_multiply_2) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF, nd4j::DataType::BFLOAT16};
for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
NDArray input('c', {4, 25}, t);
NDArray output('c', {input.lengthOf() / 4}, t);
NDArray exp = output.ulike();
input.assign(1.5);
output.assign(2);
exp.assign(10.125);
multiplyHost(input, output);
// output.printBuffer("multiply 2");
ASSERT_EQ(exp, output);
}
}
TEST_F(AtomicTests, test_sum) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF, nd4j::DataType::INT16};
for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
NDArray input('c', {4, 25}, t);
NDArray output('c', {input.lengthOf() / 4}, t);
NDArray exp = output.ulike();
input.assign(1);
output.assign(1);
exp.assign(5);
sumHost(input, output);
// output.printIndexedBuffer("Sum");
ASSERT_EQ(exp, output);
}
}
TEST_F(AtomicTests, test_sub) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::HALF};
for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
NDArray input('c', {4, 25}, t);
NDArray output('c', {input.lengthOf() / 4}, t);
NDArray exp = output.ulike();
input.assign(1);
output.assign(5);
exp.assign(1);
subHost(input, output);
// output.printBuffer("Sub");
ASSERT_EQ(exp, output);
}
}
TEST_F(AtomicTests, test_div) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::BFLOAT16, nd4j::DataType::HALF};
for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
NDArray input('c', {4, 25}, t);
NDArray output('c', {input.lengthOf() / 4}, t);
NDArray exp = output.ulike();
input.assign(2);
output.assign(32);
exp.assign(2);
divHost(input, output);
// output.printBuffer("Div");
ASSERT_EQ(exp, output);
}
} }

View File

@ -2785,6 +2785,100 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) {
delete results; delete results;
} }
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) {
NDArray x = NDArrayFactory::create<float>('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f,
0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f,
0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f});
NDArray min = NDArrayFactory::create<float>({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer("Quantized03");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) {
NDArray x = NDArrayFactory::create<float>('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f,
0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f,
0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f});
NDArray min = NDArrayFactory::create<float>({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {8}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer("Quantized03_1");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) {
NDArray x = NDArrayFactory::create<float>('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f,
0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f,
0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f});
NDArray min = NDArrayFactory::create<float>({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
result->printIndexedBuffer("Quantized03_2");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) {
NDArray x = NDArrayFactory::create<float>('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f,
0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f});
NDArray min = NDArrayFactory::create<float>({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
NDArray max = NDArrayFactory::create<float>({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {6}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
result->printIndexedBuffer("Quantized03_3");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {