fix bits_hamming_distance for ppc

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-01 19:33:23 +03:00
parent ef1de6a4aa
commit 3679e55c49
1 changed file with 20 additions and 9 deletions

View File

@ -33,34 +33,45 @@ namespace nd4j {
auto yBuffer = y.bufferAsT<X>(); auto yBuffer = y.bufferAsT<X>();
Nd4jLong distance = 0; Nd4jLong distance = 0;
auto lengthOf = x.lengthOf();
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
Nd4jLong intermediate[256];
// nullify temp values
for (int e = 0; e < maxThreads; e++)
intermediate[e] = 0;
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) { if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < x.lengthOf(); e++) { for (Nd4jLong e = 0; e < lengthOf; e++) {
auto _x = static_cast<unsigned long long>(xBuffer[e]); auto _x = static_cast<unsigned long long>(xBuffer[e]);
auto _y = static_cast<unsigned long long>(yBuffer[e]); auto _y = static_cast<unsigned long long>(yBuffer[e]);
distance += __builtin_popcountll(_x ^ _y); intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
} }
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) { } else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < x.lengthOf(); e++) { for (Nd4jLong e = 0; e < lengthOf; e++) {
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]); auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]); auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
distance += __builtin_popcountll(_x ^ _y); intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
} }
} else { } else {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance) PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < x.lengthOf(); e++) { for (Nd4jLong e = 0; e < lengthOf; e++) {
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e)); auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e)); auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
distance += __builtin_popcountll(_x ^ _y); intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
} }
} }
// accumulate intermediate variables into output array
for (int e = 0; e < maxThreads; e++)
distance += intermediate[e];
z.p(0, distance); z.p(0, distance);
} }