parent
ef1de6a4aa
commit
3679e55c49
|
@ -33,34 +33,45 @@ namespace nd4j {
|
||||||
auto yBuffer = y.bufferAsT<X>();
|
auto yBuffer = y.bufferAsT<X>();
|
||||||
|
|
||||||
Nd4jLong distance = 0;
|
Nd4jLong distance = 0;
|
||||||
|
auto lengthOf = x.lengthOf();
|
||||||
|
const int maxThreads = nd4j::math::nd4j_min<int>(256, omp_get_max_threads());
|
||||||
|
Nd4jLong intermediate[256];
|
||||||
|
|
||||||
|
// nullify temp values
|
||||||
|
for (int e = 0; e < maxThreads; e++)
|
||||||
|
intermediate[e] = 0;
|
||||||
|
|
||||||
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
|
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||||
auto _x = static_cast<unsigned long long>(xBuffer[e]);
|
auto _x = static_cast<unsigned long long>(xBuffer[e]);
|
||||||
auto _y = static_cast<unsigned long long>(yBuffer[e]);
|
auto _y = static_cast<unsigned long long>(yBuffer[e]);
|
||||||
|
|
||||||
distance += __builtin_popcountll(_x ^ _y);
|
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||||
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
|
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
|
||||||
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
|
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
|
||||||
|
|
||||||
distance += __builtin_popcountll(_x ^ _y);
|
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
for (Nd4jLong e = 0; e < lengthOf; e++) {
|
||||||
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
|
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
|
||||||
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
|
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
|
||||||
|
|
||||||
distance += __builtin_popcountll(_x ^ _y);
|
intermediate[omp_get_thread_num()] += __builtin_popcountll(_x ^ _y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// accumulate intermediate variables into output array
|
||||||
|
for (int e = 0; e < maxThreads; e++)
|
||||||
|
distance += intermediate[e];
|
||||||
|
|
||||||
z.p(0, distance);
|
z.p(0, distance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue