cavis/libnd4j/include/ops/impl/specials_sparse.cpp

224 lines
7.9 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/specials_sparse.h>
#include <system/dll.h>
#include <system/pointercast.h>
#include <stdio.h>
#include <stdlib.h>
#ifdef _OPENMP
#include <omp.h>
#endif
#include <types/float16.h>
#include <types/types.h>
namespace sd {
namespace sparse {
/**
 * Prints one coordinate tuple of a flat index buffer to stdout as " [a, b, c] ".
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param rank    number of coordinates per entry
 * @param x       position of the entry to print
 */
template <typename T>
void SparseUtils<T>::printIndex(Nd4jLong *indices, int rank, int x) {
    // Widen before multiplying: `x * rank` evaluated in plain int can overflow for
    // large buffers, while the sibling helpers already address entries in Nd4jLong.
    const Nd4jLong *tuple = indices + static_cast<Nd4jLong>(x) * rank;

    printf(" [");
    for (int e = 0; e < rank; e++) {
        if (e > 0)
            printf(", ");
        printf("%lld", static_cast<long long>(tuple[e]));
    }
    printf("] ");
}
/**
 * Lexicographic "less than" between two coordinate tuples of a flat index buffer.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param rank    number of coordinates per entry
 * @param x       position of the first entry
 * @param y       position of the second entry
 * @return true when entry x orders strictly before entry y; false when it orders
 *         after it or when both tuples are equal
 */
template <typename T>
bool SparseUtils<T>::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) {
    const Nd4jLong *tupleX = indices + x * rank;
    const Nd4jLong *tupleY = indices + y * rank;

    // Walk the coordinates starting from the outer dimension; the first pair that
    // differs decides the ordering.
    for (int dim = 0; dim < rank; ++dim) {
        if (tupleX[dim] != tupleY[dim])
            return tupleX[dim] < tupleY[dim];
    }

    // every coordinate matched: equal tuples are not "less than"
    return false;
}
/**
 * Lexicographic "greater than" between two coordinate tuples of a flat index buffer.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param rank    number of coordinates per entry
 * @param x       position of the first entry
 * @param y       position of the second entry
 * @return true when entry x orders strictly after entry y; false when it orders
 *         before it or when both tuples are equal
 */
template <typename T>
bool SparseUtils<T>::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) {
    const Nd4jLong *tupleX = indices + x * rank;
    const Nd4jLong *tupleY = indices + y * rank;

    // Walk the coordinates starting from the outer dimension; the first pair that
    // differs decides the ordering.
    for (int dim = 0; dim < rank; ++dim) {
        if (tupleX[dim] != tupleY[dim])
            return tupleX[dim] > tupleY[dim];
    }

    // every coordinate matched: equal tuples are not "greater than"
    return false;
}
/**
 * Exchanges two entries of a COO structure in place: both the coordinate tuples in
 * the index buffer and the associated values.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param array   value buffer, one value per entry
 * @param rank    number of coordinates per entry
 * @param x       position of the first entry
 * @param y       position of the second entry
 */
template <typename T>
void SparseUtils<T>::swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y) {
    Nd4jLong *tupleX = indices + x * rank;
    Nd4jLong *tupleY = indices + y * rank;

    // exchange the coordinate tuples, one dimension at a time
    for (int dim = 0; dim < rank; ++dim) {
        Nd4jLong heldCoordinate = tupleX[dim];
        tupleX[dim] = tupleY[dim];
        tupleY[dim] = heldCoordinate;
    }

    // exchange the values that belong to those tuples
    T heldValue = array[x];
    array[x] = array[y];
    array[y] = heldValue;
}
/**
 * Median-of-three pivot selection for the COO quicksort.
 *
 * Orders the entries at `left`, the midpoint and `right` among themselves (moving
 * indices and values together) so that afterwards left <= midpoint <= right, and
 * returns the midpoint position, which then holds the median of the three.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param array   value buffer, one value per entry
 * @param left    first position of the range (inclusive)
 * @param right   last position of the range (inclusive)
 * @param rank    number of coordinates per entry
 * @return position of the pivot entry
 */
template <typename T>
Nd4jLong SparseUtils<T>::coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
                                                 int rank) {
    const Nd4jLong middle = (left + right) / 2;

    // compare-and-swap: afterwards the entry at `lo` does not order after the one at `hi`
    auto orderPair = [&](Nd4jLong lo, Nd4jLong hi) {
        if (ltIndices(indices, rank, hi, lo)) {
            swapEverything(indices, array, rank, hi, lo);
        }
    };

    orderPair(left, middle);   // left <= middle
    orderPair(left, right);    // left <= right
    orderPair(middle, right);  // middle <= right

    // the midpoint now holds the median of the three sampled entries
    return middle;
}
/**
 * Recursive quicksort step over the inclusive range [left, right] of a COO structure.
 * Entries are ordered by their coordinate tuples; values are moved along with them.
 *
 * Ranges whose span (right - left) is below `cutoff` recurse directly; larger ranges
 * hand each partition to PRAGMA_OMP_TASK. Spawned tasks are not awaited here.
 * NOTE(review): presumably the enclosing parallel region of the caller guarantees
 * their completion - confirm against the PRAGMA_OMP_* macro definitions.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param array   value buffer, one value per entry
 * @param left    first position of the range (inclusive)
 * @param right   last position of the range (inclusive)
 * @param cutoff  span at or above which the two partitions are sorted as tasks
 * @param rank    number of coordinates per entry
 */
template <typename T>
void SparseUtils<T>::coo_quickSort_parallel_internal(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, int cutoff, int rank) {
    const Nd4jLong span = right - left;  // elements to be partitioned - 1

    // Empty (span < 0) and single-element (span == 0) ranges are already sorted.
    // Returning here also keeps findPivot from comparing/swapping the entry at
    // position -1 when an empty array is passed in (left == 0, right == -1).
    if (span < 1) {
        return;
    }

    if (span == 1) {
        // only 2 elements to partition. swap if needed and return directly without further sorting.
        if (ltIndices(indices, rank, right, left)) {
            swapEverything(indices, array, rank, left, right);
        }
        return;
    }

    // find the pivot; as a side effect the entries are ordered left <= mid <= right
    Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank);
    if (span == 2) {
        // only 3 elements to partition. findPivot has already sorted them. no further sorting is needed.
        return;
    }

    // scans upwards for an entry that is not smaller than the pivot - the leftmost
    // element is already partitioned because of findPivot.
    Nd4jLong i = left + 1;
    // scans downwards for an entry that is not greater than the pivot - the rightmost
    // element is already partitioned because of findPivot.
    Nd4jLong j = right - 1;
    {
        // flag that indicates that pivot index lies between i and j and *could* be swapped.
        bool checkPivot = true;

        /* PARTITION PART */
        while (i <= j) {
            while (ltIndices(indices, rank, i, pvt))
                i++;
            while (gtIndices(indices, rank, j, pvt))
                j--;

            if (i <= j) {
                if (i != j) {  // swap can be fairly expensive. don't swap i -> i
                    swapEverything(indices, array, rank, i, j);
                }
                // only check pivot if it hasn't already been swapped.
                if (checkPivot) {
                    // the comparisons above address the pivot by position, so follow it when it moves
                    if (pvt == j) {
                        pvt = i;
                        checkPivot = false;
                    } else if (pvt == i) {
                        pvt = j;
                        checkPivot = false;
                    }
                }
                i++;
                j--;
            }
        }
    }

    // Recurse into [left, j] and [i, right]. Partitions with fewer than two elements
    // are skipped in both branches so that no task is created just to return at once.
    if (span < cutoff) {
        if (left < j) {
            coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank);
        }
        if (i < right) {
            coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank);
        }
    } else {
        if (left < j) {
            PRAGMA_OMP_TASK
            { coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); }
        }
        if (i < right) {
            PRAGMA_OMP_TASK
            { coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); }
        }
    }
}
/**
 * Entry point of the COO quicksort: sorts `lenArray` entries by their coordinate
 * tuples, moving the values along with them.
 *
 * The recursion is started from a single thread inside a parallel region opened with
 * `numThreads` threads; ranges spanning fewer than 1000 positions are sorted without
 * spawning tasks.
 *
 * @param indices    flat buffer holding `rank` consecutive coordinates per entry
 * @param array      value buffer, one value per entry
 * @param lenArray   number of entries
 * @param numThreads thread count passed to the parallel region
 * @param rank       number of coordinates per entry
 */
template <typename T>
void SparseUtils<T>::coo_quickSort_parallel(Nd4jLong *indices, T *array, Nd4jLong lenArray, int numThreads, int rank) {
    // Zero or one entry: nothing to sort. For lenArray == 0 this is required for
    // correctness, because the recursive step would be started with right == -1 and
    // would compare (and possibly swap) the entry in front of the buffers.
    if (lenArray < 2) {
        return;
    }

    int cutoff = 1000;

    PRAGMA_OMP_PARALLEL_THREADS(numThreads)
    {
        PRAGMA_OMP_SINGLE_ARGS(nowait)
        {
            coo_quickSort_parallel_internal(indices, array, 0, lenArray - 1, cutoff, rank);
        }
    }
}
/**
 * Type-erased entry point: sorts a COO structure whose values arrive as `void*`.
 *
 * The value buffer is reinterpreted as `T*` and handed to coo_quickSort_parallel,
 * using omp_get_max_threads() threads when built with OpenMP and 1 otherwise.
 *
 * @param indices flat buffer holding `rank` consecutive coordinates per entry
 * @param vx      value buffer, treated as an array of T
 * @param length  number of entries
 * @param rank    number of coordinates per entry
 */
template <typename T>
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
#ifdef _OPENMP
    const int threadCount = omp_get_max_threads();
#else
    const int threadCount = 1;
#endif

    T *typedValues = static_cast<T *>(vx);
    coo_quickSort_parallel(indices, typedValues, length, threadCount, rank);
}
// NOTE(review): BUILD_SINGLE_TEMPLATE and LIBND4J_TYPES are defined outside this file;
// presumably this expands to one explicit instantiation of SparseUtils per listed
// type, which the out-of-line template definitions above would need - confirm.
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
}
}