/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/specials_sparse.h>
#include <system/dll.h>
#include <system/pointercast.h>
#include <stdio.h>
#include <stdlib.h>
#include <helpers/shape.h>
// needed for std::sort, std::vector and std::runtime_error used below
#include <algorithm>
#include <stdexcept>
#include <vector>

#ifdef _OPENMP
#include <omp.h>
#endif

#include <types/float16.h>
#include <types/types.h>

namespace sd {
namespace sparse {

template <typename T>
void SparseUtils<T>::printIndex(Nd4jLong *indices, int rank, int x) {
    printf(" [");
    for (int e = 0; e < rank; e++) {
        if (e > 0)
            printf(", ");

        printf("%lld", (long long) indices[x * rank + e]);
    }
    printf("] ");
}

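// Lexicographic "less than" between the rank-length index tuples stored at rows x and y
// of the row-major `indices` buffer. For example, with rank 2, [0, 3] sorts before [1, 0]
// because the outermost dimension is compared first.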
template <typename T>
bool SparseUtils<T>::ltIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) {
    for (int e = 0; e < rank; e++) {
        Nd4jLong idxX = indices[x * rank + e];
        Nd4jLong idxY = indices[y * rank + e];
        // we're comparing indices one by one, starting from the outer dimension
        if (idxX < idxY) {
            return true;
        } else if (idxX == idxY) {
            // do nothing, continue to the next dimension
        } else {
            return false;
        }
    }

    // all dimensions are equal, so x is not strictly less than y
    return false;
}

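// Lexicographic "greater than" counterpart of ltIndices; together the two provide the
// comparisons needed by the quicksort partitioning below.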
template <typename T>
bool SparseUtils<T>::gtIndices(Nd4jLong *indices, int rank, Nd4jLong x, Nd4jLong y) {
    for (int e = 0; e < rank; e++) {
        // we're comparing indices one by one, starting from the outer dimension
        Nd4jLong idxX = indices[x * rank + e];
        Nd4jLong idxY = indices[y * rank + e];
        if (idxX > idxY) {
            return true;
        } else if (idxX == idxY) {
            // do nothing, continue to the next dimension
        } else {
            return false;
        }
    }
    return false;
}

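// Swaps both the rank-length index tuples and the corresponding values at positions x
// and y, keeping the COO indices and values buffers in sync while sorting.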
template <typename T>
void SparseUtils<T>::swapEverything(Nd4jLong *indices, T *array, int rank, Nd4jLong x, Nd4jLong y) {
    // swap indices
    for (int e = 0; e < rank; e++) {
        Nd4jLong tmp = indices[x * rank + e];
        indices[x * rank + e] = indices[y * rank + e];
        indices[y * rank + e] = tmp;
    }

    // swap values
    T tmp = array[x];
    array[x] = array[y];
    array[y] = tmp;
}

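// Median-of-three pivot selection: orders the leftmost, middle and rightmost elements of
// [left, right] in place and returns the middle position. This guards against the
// quadratic worst case of quicksort on already-sorted input.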
template <typename T>
Nd4jLong SparseUtils<T>::coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
                                                 int rank) {
    Nd4jLong mid = left + (right - left) / 2; // overflow-safe midpoint

    // ensure left < mid
    if (ltIndices(indices, rank, mid, left)) {
        swapEverything(indices, array, rank, mid, left);
    }

    // ensure left < right
    if (ltIndices(indices, rank, right, left)) {
        swapEverything(indices, array, rank, right, left);
    }

    // ensure mid < right
    if (ltIndices(indices, rank, right, mid)) {
        swapEverything(indices, array, rank, right, mid);
    }

    // mid now holds the median of the three and makes a good pivot
    return mid;
}

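// Recursive Hoare-style partitioning step: partitions [left, right] around a
// median-of-three pivot, then sorts both halves - sequentially when the partition is
// smaller than `cutoff`, otherwise as independent OpenMP tasks.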
template <typename T>
void SparseUtils<T>::coo_quickSort_parallel_internal(Nd4jLong *indices, T* array, Nd4jLong left, Nd4jLong right, int cutoff, int rank) {
    Nd4jLong span = right - left; // number of elements to be partitioned, minus 1

    if (span == 1) {
        // only 2 elements to partition: swap if needed and return directly without further sorting
        if (ltIndices(indices, rank, right, left)) {
            swapEverything(indices, array, rank, left, right);
        }
        return;
    }

    // find a good pivot and order the sentinels so that left < mid < right
    Nd4jLong pvt = coo_quickSort_findPivot(indices, array, left, right, rank);

    if (span == 2) {
        // only 3 elements to partition: findPivot has already sorted them, so no further sorting is needed
        return;
    }

    // left scan index - the leftmost element is already partitioned because of findPivot
    Nd4jLong i = left + 1;

    // right scan index - the rightmost element is already partitioned because of findPivot
    Nd4jLong j = right - 1;

    {
        // flag that indicates the pivot index lies between i and j and *could* be swapped
        bool checkPivot = true;

        /* PARTITION PART */
        while (i <= j) {
            // advance i past elements smaller than the pivot
            while (ltIndices(indices, rank, i, pvt))
                i++;

            // retreat j past elements greater than the pivot
            while (gtIndices(indices, rank, j, pvt))
                j--;

            if (i <= j) {
                if (i != j) { // a swap can be fairly expensive - don't swap i with itself
                    swapEverything(indices, array, rank, i, j);
                }

                // only check the pivot if it hasn't already been swapped
                if (checkPivot) {
                    // if we moved the pivot, update the pivot index accordingly
                    if (pvt == j) {
                        pvt = i;
                        checkPivot = false;
                    } else if (pvt == i) {
                        pvt = j;
                        checkPivot = false;
                    }
                }

                i++;
                j--;
            }
        }
    }

    if (span < cutoff) {
        // small partitions: recurse sequentially to avoid task-scheduling overhead
        if (left < j) { coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); }
        if (i < right) { coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); }
    } else {
        // large partitions: sort both halves concurrently as OpenMP tasks
        PRAGMA_OMP_TASK
        { coo_quickSort_parallel_internal(indices, array, left, j, cutoff, rank); }

        PRAGMA_OMP_TASK
        { coo_quickSort_parallel_internal(indices, array, i, right, cutoff, rank); }
    }
}

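// Entry point of the parallel quicksort: opens a parallel region with numThreads
// threads and lets a single thread start the recursion; nested calls then offload
// large partitions as tasks to the rest of the team.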
template <typename T>
void SparseUtils<T>::coo_quickSort_parallel(Nd4jLong *indices, T* array, Nd4jLong lenArray, int numThreads, int rank) {
    // partitions smaller than this are sorted sequentially instead of spawning new tasks
    int cutoff = 1000;

    PRAGMA_OMP_PARALLEL_THREADS(numThreads)
    {
        PRAGMA_OMP_SINGLE_ARGS(nowait)
        {
            coo_quickSort_parallel_internal(indices, array, 0, lenArray - 1, cutoff, rank);
        }
    }
}

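// Type-erased entry point: sorts COO-formatted indices and their values into
// lexicographic order, using all available OpenMP threads when OpenMP is compiled in.
// A minimal usage sketch with hypothetical data (three rank-2 entries, double values):
//
//   Nd4jLong idx[]  = {1, 0,  0, 2,  0, 1};   // coordinates, row-major
//   double   vals[] = {3.0, 2.0, 1.0};
//   SparseUtils<double>::sortCooIndicesGeneric(idx, vals, 3, 2);
//   // idx is now {0, 1,  0, 2,  1, 0} and vals is {1.0, 2.0, 3.0}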
template <typename T>
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
    auto values = reinterpret_cast<T *>(vx);
#ifdef _OPENMP
    coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
#else
    coo_quickSort_parallel(indices, values, length, 1, rank);
#endif
}

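// Instantiate SparseUtils for every type in LIBND4J_TYPES.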
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);

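// Converts multi-dimensional COO indices into flat (raveled) indices using the strides
// from shapeInfo: flat = sum over j of idx[j] * stride[j]. For example, for shape [3, 4]
// with C-order strides [4, 1], the multi-index (1, 2) ravels to 1 * 4 + 2 = 6.
// Out-of-range indices are clipped, wrapped, or reported as an error, depending on `mode`.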
void IndexUtils::ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
    Nd4jLong * shape = shape::shapeOf(shapeInfo);
    Nd4jLong * stride = shape::stride(shapeInfo);
    Nd4jLong rank = shape::rank(shapeInfo);
    int errorCount = 0;

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong i = 0; i < length; ++i) {
        Nd4jLong raveledIndex = 0;
        for (Nd4jLong j = 0; j < rank; ++j) {
            Nd4jLong idx = indices[i * rank + j];
            if (idx >= shape[j]) {
                // the index does not fit into the shape at dimension j
                if (mode == ND4J_CLIPMODE_CLIP) {
                    // set idx to the largest valid value (clip to shape)
                    idx = shape[j] - 1;
                } else if (mode == ND4J_CLIPMODE_WRAP) {
                    idx %= shape[j];
                } else {
                    // mode is ND4J_CLIPMODE_THROW or unknown; either way an error must be raised.
                    // we cannot throw inside the parallel region, so count the error and throw after the loop.
                    nd4j_printf("sparse::IndexUtils::ravelMultiIndex Cannot ravel index at element %lld, does not fit into specified shape.\n", (long long) i);
                    ++errorCount;
                }
            }
            raveledIndex += idx * stride[j];
        }
        flatIndices[i] = raveledIndex;
    }

    if (errorCount > 0) {
        // throw if an error occurred inside the loop
        throw std::runtime_error("sparse::IndexUtils::ravelMultiIndex Cannot ravel index");
    }
}

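// The inverse of ravelMultiIndex: reconstructs multi-dimensional indices from flat ones
// by repeated division, processing dimensions from the largest stride downwards.
// Continuing the example above, flat index 6 with strides [4, 1] gives 6 / 4 = 1
// (remainder 2), then 2 / 1 = 2 - i.e. the multi-index (1, 2).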
void IndexUtils::unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
    Nd4jLong * shape = shape::shapeOf(shapeInfo);
    Nd4jLong * stride = shape::stride(shapeInfo);
    Nd4jLong rank = shape::rank(shapeInfo);
    int errorCount = 0;

    // unravelOrder ensures that the dimensions with the largest stride are unraveled first.
    // create an array with the elements 0..rank-1
    int * unravelOrder = shape::range<int>(0, rank);

    // sort the order by descending stride length
    std::sort(unravelOrder, unravelOrder + rank,
              [&](int i1, int i2) { return stride[i1] > stride[i2]; });

    // calculate the largest raveled index that fits into the passed shape
    Nd4jLong maxRaveledIndex = shape[unravelOrder[0]] * stride[unravelOrder[0]] - 1;

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong i = 0; i < length; ++i) {
        Nd4jLong raveledIndex = flatIndices[i];
        if (raveledIndex > maxRaveledIndex) {
            // we cannot throw inside the parallel region, so count the error and throw after the loop
            nd4j_printf("sparse::IndexUtils::unravelIndex Cannot unravel index at element %lld. raveled index of %lld does not fit into specified shape.\n",
                        (long long) i, (long long) raveledIndex);
            ++errorCount;
        }

        for (int * it = unravelOrder; it != unravelOrder + rank; it++) {
            int j = *it;
            // how many strides of this size fit?
            indices[i * rank + j] = raveledIndex / stride[j];

            // keep the remainder for the subsequent smaller strides
            raveledIndex %= stride[j];
        }
    }

    // release the order buffer before a potential throw so it cannot leak
    delete[] unravelOrder;

    if (errorCount > 0) {
        // throw if an error occurred inside the loop
        nd4j_printf("Largest raveled index is: %lld, ", (long long) maxRaveledIndex);
        std::vector<Nd4jLong> v(shape, shape + rank);
        nd4j_printv("Shape: ", v);
        throw std::runtime_error("sparse::IndexUtils::unravelIndex Cannot unravel index");
    }
}

} // namespace sparse
} // namespace sd