parent
ef1de6a4aa
commit
3679e55c49
|
@ -33,34 +33,45 @@ namespace nd4j {
|
|||
auto yBuffer = y.bufferAsT<X>();
|
||||
|
||||
Nd4jLong distance = 0;
|
||||
auto lengthOf = x.lengthOf();
|
||||
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
||||
Nd4jLong intermediate[256];
|
||||
|
||||
// nullify temp values
|
||||
for (int e = 0; e < maxThreads; e++)
|
||||
intermediate[e] = 0;
|
||||
|
||||
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||
auto _x = static_cast<unsigned long long>(xBuffer[e]);
|
||||
auto _y = static_cast<unsigned long long>(yBuffer[e]);
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
|
||||
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
|
||||
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
} else {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
|
||||
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
}
|
||||
|
||||
// accumulate intermediate variables into output array
|
||||
for (int e = 0; e < maxThreads; e++)
|
||||
distance += intermediate[e];
|
||||
|
||||
z.p(0, distance);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue