add IndexUtils class containing ravelMultiIndex and unravelIndex methods (#9122)
Also add test functions both for Java and C++. Signed-off-by: Péter Zarándy <pza@wehowsky.com>master
parent
a1fcc5f19f
commit
95ca39bd21
|
@ -660,6 +660,13 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline static void execRavelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||||
|
sd::sparse::IndexUtils::ravelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static void execUnravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||||
|
sd::sparse::IndexUtils::unravelIndex(indices, flatIndices, length, shapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
||||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
|
|
@ -1482,6 +1482,30 @@ ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
|
||||||
Nd4jLong length,
|
Nd4jLong length,
|
||||||
const Nd4jLong *xShapeInfo);
|
const Nd4jLong *xShapeInfo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param extraPointers not used
|
||||||
|
* @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
|
||||||
|
* @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
|
||||||
|
* @param length number of non-zero entries (length of flatIndices)
|
||||||
|
* @param fullShapeBuffer DataBuffer with ShapeInfo for the full matrix to be flattened
|
||||||
|
* @param mode clipMode determines the strategy to use if some of the the passed COO indices does
|
||||||
|
* not fit into the shape determined by fullShapeBuffer
|
||||||
|
* 0 throw an exception (default)
|
||||||
|
* 1 wrap around shape
|
||||||
|
* 2 clip to shape
|
||||||
|
*/
|
||||||
|
ND4J_EXPORT void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param extraPointers not used
|
||||||
|
* @param indices DataBuffer where the unraveled COO indices are to be written
|
||||||
|
* @param flatIndices DataBuffer containing the raveled/flattened indices to be unravel
|
||||||
|
* @param length number of non-zero entries (length of flatIndices)
|
||||||
|
* @param fullShapeBuffer DataBuffer with ShapeInfo for the full matrix to be unraveled
|
||||||
|
*/
|
||||||
|
ND4J_EXPORT void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
||||||
|
|
||||||
|
|
|
@ -1843,6 +1843,14 @@ void sortCooIndices(Nd4jPointer *extraPointers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||||
|
NativeOpExecutioner::execRavelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||||
|
NativeOpExecutioner::execUnravelIndex(indices, flatIndices, length, shapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
||||||
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
|
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2564,6 +2564,15 @@ void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values,
|
||||||
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||||
|
throw std::runtime_error("ravelMultiIndex:: Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||||
|
throw std::runtime_error("unravelIndex:: Not implemented yet");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
|
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include <system/pointercast.h>
|
#include <system/pointercast.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#include <helpers/shape.h>
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
#include <omp.h>
|
#include <omp.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -219,5 +221,89 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
|
||||||
|
|
||||||
|
void IndexUtils::ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode){
|
||||||
|
Nd4jLong * shape = shape::shapeOf(shapeInfo);
|
||||||
|
Nd4jLong * stride = shape::stride(shapeInfo);
|
||||||
|
Nd4jLong rank = shape::rank(shapeInfo);
|
||||||
|
int errorCount = 0;
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (Nd4jLong i = 0; i < length; ++i){
|
||||||
|
Nd4jLong raveledIndex = 0;
|
||||||
|
for (Nd4jLong j = 0; j < rank; ++j){
|
||||||
|
Nd4jLong idx = indices[i * rank + j];
|
||||||
|
if (idx >= shape[j]) {
|
||||||
|
// index does not fit into shape at j dimension.
|
||||||
|
if (mode == ND4J_CLIPMODE_CLIP){
|
||||||
|
// set idx to largest possible value (clip to shape)
|
||||||
|
idx = shape[j] - 1;
|
||||||
|
}
|
||||||
|
else if (mode == ND4J_CLIPMODE_WRAP) {
|
||||||
|
idx %= shape[j];
|
||||||
|
} else {
|
||||||
|
// mode is ND4J_CLIPMODE_THROW or is unknown. either way. throw an error later.
|
||||||
|
// cannot throw here because of parallel region
|
||||||
|
nd4j_printf("sparse::IndexUtils::ravelMultiIndex Cannot ravel index at element %d, does not fit into specified shape.\n", i);
|
||||||
|
++errorCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
raveledIndex += idx * stride[j];
|
||||||
|
}
|
||||||
|
flatIndices[i] = raveledIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (errorCount > 0){
|
||||||
|
// throw error if one ocurred in loop
|
||||||
|
throw std::runtime_error("sparse::IndexUtils::ravelMultiIndex Cannot ravel index");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void IndexUtils::unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo){
|
||||||
|
Nd4jLong * shape = shape::shapeOf(shapeInfo);
|
||||||
|
Nd4jLong * stride = shape::stride(shapeInfo);
|
||||||
|
Nd4jLong rank = shape::rank(shapeInfo);
|
||||||
|
int errorCount = 0;
|
||||||
|
|
||||||
|
// unravelOrder ensures that the dimensions with largest stride are unraveled first.
|
||||||
|
// create vector with elements 0..rank
|
||||||
|
int * unravelOrder = shape::range<int>(0, rank);
|
||||||
|
|
||||||
|
// sort order according to stride length.
|
||||||
|
std::sort(unravelOrder, unravelOrder + rank,
|
||||||
|
[&](int i1, int i2) { return stride[i1] > stride[i2]; } );
|
||||||
|
|
||||||
|
// calculate the largest raveled index that will fit into passed shape
|
||||||
|
Nd4jLong maxRaveledIndex = shape[unravelOrder[0]] * stride[unravelOrder[0]] - 1;
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (Nd4jLong i = 0; i < length; ++i){
|
||||||
|
Nd4jLong raveledIndex = flatIndices[i];
|
||||||
|
if (raveledIndex > maxRaveledIndex){
|
||||||
|
// cannot throw here because of parallel region
|
||||||
|
nd4j_printf("sparse::IndexUtils::unravelIndex Cannot unravel index at element %d. raveled index of %d does not fit into specified shape.\n", i, raveledIndex);
|
||||||
|
++errorCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int * it = unravelOrder; it != unravelOrder + rank; it++){
|
||||||
|
int j = *it;
|
||||||
|
// how many strides of this size?
|
||||||
|
indices[i * rank + j] = raveledIndex / stride[j];
|
||||||
|
|
||||||
|
// remainder for subsequent smaller strides.
|
||||||
|
raveledIndex %= stride[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (errorCount > 0){
|
||||||
|
// throw error if one ocurred in loop
|
||||||
|
nd4j_printf("Largest raveled index is: %d, ", maxRaveledIndex)
|
||||||
|
std::vector<Nd4jLong> v(shape, shape + rank);
|
||||||
|
nd4j_printv("Shape: ", v);
|
||||||
|
throw std::runtime_error("sparse::IndexUtils::unravelIndex Cannot unravel index");
|
||||||
|
}
|
||||||
|
|
||||||
|
delete[] unravelOrder;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,13 @@
|
||||||
#ifndef LIBND4J_SPECIALS_SPARSE_H
|
#ifndef LIBND4J_SPECIALS_SPARSE_H
|
||||||
#define LIBND4J_SPECIALS_SPARSE_H
|
#define LIBND4J_SPECIALS_SPARSE_H
|
||||||
|
|
||||||
|
#define ND4J_CLIPMODE_THROW 0
|
||||||
|
#define ND4J_CLIPMODE_WRAP 1
|
||||||
|
#define ND4J_CLIPMODE_CLIP 2
|
||||||
|
|
||||||
#include <system/pointercast.h>
|
#include <system/pointercast.h>
|
||||||
|
#include <system/dll.h>
|
||||||
|
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace sparse {
|
namespace sparse {
|
||||||
|
@ -61,6 +67,25 @@ namespace sd {
|
||||||
int rank);
|
int rank);
|
||||||
|
|
||||||
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
|
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class ND4J_EXPORT IndexUtils {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* Converts indices in COO format into an array of flat indices
|
||||||
|
*
|
||||||
|
* based on numpy.ravel_multi_index
|
||||||
|
*/
|
||||||
|
static void ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts flat indices to index matrix in COO format
|
||||||
|
*
|
||||||
|
* based on numpy.unravel_index
|
||||||
|
*/
|
||||||
|
static void unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,8 @@ public:
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(SparseUtilsTest, SortCOOindices_Test) {
|
TEST_F(SparseUtilsTest, SortCOOindices_Test) {
|
||||||
|
|
||||||
#ifndef __CUDABLAS__
|
#ifndef __CUDABLAS__
|
||||||
|
|
||||||
|
|
||||||
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{
|
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{
|
||||||
0,2,7,
|
0,2,7,
|
||||||
|
@ -143,5 +144,102 @@ TEST_F(SparseUtilsTest, SortCOOindices_Test) {
|
||||||
delete[] indicesArr;
|
delete[] indicesArr;
|
||||||
delete[] expIndicesArr;
|
delete[] expIndicesArr;
|
||||||
|
|
||||||
#endif
|
|
||||||
}
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(SparseUtilsTest, RavelIndices_Test) {
|
||||||
|
|
||||||
|
#ifndef __CUDABLAS__
|
||||||
|
|
||||||
|
Nd4jLong * indicesArrExp = new Nd4jLong[nnz * rank]{
|
||||||
|
0,2,7,
|
||||||
|
2,36,35,
|
||||||
|
3,30,17,
|
||||||
|
5,12,22,
|
||||||
|
5,43,45,
|
||||||
|
6,32,11,
|
||||||
|
8,8,32,
|
||||||
|
9,29,11,
|
||||||
|
5,11,22,
|
||||||
|
15,26,16,
|
||||||
|
17,48,49,
|
||||||
|
24,28,31,
|
||||||
|
26,6,23,
|
||||||
|
31,21,31,
|
||||||
|
35,46,45,
|
||||||
|
37,13,14,
|
||||||
|
6,38,18,
|
||||||
|
7,28,20,
|
||||||
|
8,29,39,
|
||||||
|
8,32,30,
|
||||||
|
9,42,43,
|
||||||
|
11,15,18,
|
||||||
|
13,18,45,
|
||||||
|
29,26,39,
|
||||||
|
30,8,25,
|
||||||
|
42,31,24,
|
||||||
|
28,33,5,
|
||||||
|
31,27,1,
|
||||||
|
35,43,26,
|
||||||
|
36,8,37,
|
||||||
|
39,22,14,
|
||||||
|
39,24,42,
|
||||||
|
42,48,2,
|
||||||
|
43,26,48,
|
||||||
|
44,23,49,
|
||||||
|
45,18,34,
|
||||||
|
46,28,5,
|
||||||
|
46,32,17,
|
||||||
|
48,34,44,
|
||||||
|
49,38,39,
|
||||||
|
};
|
||||||
|
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank];
|
||||||
|
|
||||||
|
Nd4jLong * flatIndicesExp = new Nd4jLong[nnz]{
|
||||||
|
147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
|
||||||
|
21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
|
||||||
|
27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
|
||||||
|
126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
|
||||||
|
179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
|
||||||
|
};
|
||||||
|
|
||||||
|
Nd4jLong * flatIndices = new Nd4jLong[nnz];
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jLong * shape = new Nd4jLong[rank]{50, 60, 70};
|
||||||
|
Nd4jLong * shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
||||||
|
|
||||||
|
|
||||||
|
sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
|
||||||
|
|
||||||
|
for ( int i = 0; i < nnz; ++i){
|
||||||
|
ASSERT_EQ(flatIndicesExp[i], flatIndices[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::sparse::IndexUtils::unravelIndex(indicesArr, flatIndices, nnz, shapeInfoBuffer);
|
||||||
|
|
||||||
|
for ( int i = 0; i < nnz * rank; ++i){
|
||||||
|
ASSERT_EQ(indicesArrExp[i], indicesArr[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
shape[2] = 30;
|
||||||
|
shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
||||||
|
|
||||||
|
try {
|
||||||
|
sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
|
||||||
|
FAIL();
|
||||||
|
} catch (const std::runtime_error& e) {
|
||||||
|
// pass
|
||||||
|
}
|
||||||
|
|
||||||
|
delete[] indicesArrExp;
|
||||||
|
delete[] indicesArr;
|
||||||
|
delete[] flatIndicesExp;
|
||||||
|
delete[] flatIndices;
|
||||||
|
delete[] shape;
|
||||||
|
delete[] shapeInfoBuffer;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
|
@ -1025,6 +1025,32 @@ public interface NativeOps {
|
||||||
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param extraPointers not used
|
||||||
|
* @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
|
||||||
|
* @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
|
||||||
|
* @param length number of non-zero entries (length of flatIndices)
|
||||||
|
* @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
|
||||||
|
* @param mode clipMode determines the strategy to use if some of the the passed COO indices does
|
||||||
|
* not fit into the shape determined by fullShapeBuffer
|
||||||
|
* 0 throw an exception (default)
|
||||||
|
* 1 wrap around shape
|
||||||
|
* 2 clip to shape
|
||||||
|
*/
|
||||||
|
void ravelMultiIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo, int mode);
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param extraPointers not used
|
||||||
|
* @param indices DataBuffer where the unraveled COO indices are to be written
|
||||||
|
* @param flatIndices DataBuffer containing the raveled/flattened indices to be unravel
|
||||||
|
* @param length number of non-zero entries (length of flatIndices)
|
||||||
|
* @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
|
||||||
|
*/
|
||||||
|
void unravelIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||||
|
|
||||||
|
|
||||||
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
||||||
|
|
||||||
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
||||||
|
|
|
@ -3151,6 +3151,12 @@ public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPoin
|
||||||
@Cast("Nd4jLong") long length,
|
@Cast("Nd4jLong") long length,
|
||||||
@Cast("const Nd4jLong*") long[] xShapeInfo);
|
@Cast("const Nd4jLong*") long[] xShapeInfo);
|
||||||
|
|
||||||
|
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo, int mode);
|
||||||
|
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo, int mode);
|
||||||
|
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo, int mode);
|
||||||
|
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||||
|
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||||
|
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo);
|
||||||
|
|
||||||
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
||||||
public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length);
|
public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length);
|
||||||
|
|
|
@ -0,0 +1,170 @@
|
||||||
|
package org.nd4j.linalg.specials;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.bytedeco.javacpp.LongPointer;
|
||||||
|
import org.junit.After;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.runner.RunWith;
|
||||||
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@RunWith(Parameterized.class)
|
||||||
|
public class RavelIndexTest extends BaseNd4jTest {
|
||||||
|
|
||||||
|
DataType initialType;
|
||||||
|
|
||||||
|
public RavelIndexTest(Nd4jBackend backend) {
|
||||||
|
super(backend);
|
||||||
|
this.initialType = Nd4j.dataType();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUp() {
|
||||||
|
Nd4j.setDataType(DataType.FLOAT);
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
public void setDown() {
|
||||||
|
Nd4j.setDataType(initialType);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public char ordering() {
|
||||||
|
return 'c';
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void ravelIndexesTest() {
|
||||||
|
// FIXME: we don't want this test running on cuda for now
|
||||||
|
if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
|
||||||
|
return;
|
||||||
|
|
||||||
|
long[] multiIdxArray = new long[] {
|
||||||
|
0,2,7,
|
||||||
|
2,36,35,
|
||||||
|
3,30,17,
|
||||||
|
5,12,22,
|
||||||
|
5,43,45,
|
||||||
|
6,32,11,
|
||||||
|
8,8,32,
|
||||||
|
9,29,11,
|
||||||
|
5,11,22,
|
||||||
|
15,26,16,
|
||||||
|
17,48,49,
|
||||||
|
24,28,31,
|
||||||
|
26,6,23,
|
||||||
|
31,21,31,
|
||||||
|
35,46,45,
|
||||||
|
37,13,14,
|
||||||
|
6,38,18,
|
||||||
|
7,28,20,
|
||||||
|
8,29,39,
|
||||||
|
8,32,30,
|
||||||
|
9,42,43,
|
||||||
|
11,15,18,
|
||||||
|
13,18,45,
|
||||||
|
29,26,39,
|
||||||
|
30,8,25,
|
||||||
|
42,31,24,
|
||||||
|
28,33,5,
|
||||||
|
31,27,1,
|
||||||
|
35,43,26,
|
||||||
|
36,8,37,
|
||||||
|
39,22,14,
|
||||||
|
39,24,42,
|
||||||
|
42,48,2,
|
||||||
|
43,26,48,
|
||||||
|
44,23,49,
|
||||||
|
45,18,34,
|
||||||
|
46,28,5,
|
||||||
|
46,32,17,
|
||||||
|
48,34,44,
|
||||||
|
49,38,39,
|
||||||
|
};
|
||||||
|
|
||||||
|
long[] flatIdxArray = new long[] {
|
||||||
|
147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
|
||||||
|
21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
|
||||||
|
27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
|
||||||
|
126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
|
||||||
|
179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int clipMode = 0;
|
||||||
|
|
||||||
|
|
||||||
|
long DIM = 3;
|
||||||
|
long length = multiIdxArray.length / DIM;
|
||||||
|
long[] shape = new long[] {50, 60, 70};
|
||||||
|
|
||||||
|
|
||||||
|
DataBuffer multiIdxDB = Nd4j.getDataBufferFactory().createLong(multiIdxArray);
|
||||||
|
DataBuffer flatIdxDB = Nd4j.getDataBufferFactory().createLong(flatIdxArray);
|
||||||
|
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
|
||||||
|
|
||||||
|
DataBuffer resultMulti = Nd4j.getDataBufferFactory().createLong(length*DIM);
|
||||||
|
DataBuffer resultFlat = Nd4j.getDataBufferFactory().createLong(length);
|
||||||
|
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||||
|
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
|
||||||
|
|
||||||
|
Assert.assertArrayEquals(flatIdxArray, resultFlat.asLong());
|
||||||
|
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().unravelIndex(null, (LongPointer) resultMulti.addressPointer(),
|
||||||
|
(LongPointer) flatIdxDB.addressPointer(), length, (LongPointer) shapeInfo.addressPointer());
|
||||||
|
|
||||||
|
Assert.assertArrayEquals(multiIdxArray, resultMulti.asLong());
|
||||||
|
|
||||||
|
|
||||||
|
//testing various clipMode cases
|
||||||
|
// clipMode = 0: throw an exception
|
||||||
|
try {
|
||||||
|
shape[2] = 10;
|
||||||
|
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||||
|
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
|
||||||
|
Assert.fail("No exception thrown while using CLIP_MODE_THROW.");
|
||||||
|
|
||||||
|
} catch (RuntimeException e) {
|
||||||
|
//OK
|
||||||
|
}
|
||||||
|
// clipMode = 1: wrap around shape
|
||||||
|
clipMode = 1;
|
||||||
|
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
|
||||||
|
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
|
||||||
|
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
|
||||||
|
length = 3;
|
||||||
|
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||||
|
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
|
||||||
|
Assert.assertArrayEquals(new long[] {22, 17, 15}, resultFlat.asLong());
|
||||||
|
|
||||||
|
// clipMode = 2: clip to shape
|
||||||
|
clipMode = 2;
|
||||||
|
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
|
||||||
|
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
|
||||||
|
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
|
||||||
|
length = 3;
|
||||||
|
|
||||||
|
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||||
|
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
|
||||||
|
|
||||||
|
Assert.assertArrayEquals(new long[] {22, 23, 23}, resultFlat.asLong());
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue