add IndexUtils class containing ravelMultiIndex and unravelIndex methods (#9122)

Also add tests for both Java and C++.

Signed-off-by: Péter Zarándy <pza@wehowsky.com>
master
pza94 2020-12-09 10:28:59 +01:00 committed by GitHub
parent a1fcc5f19f
commit 95ca39bd21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 462 additions and 3 deletions

View File

@ -660,6 +660,13 @@ static void execTransformBool(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
}
/**
 * Converts COO multi-indices to flat indices by forwarding all arguments,
 * unchanged, to sd::sparse::IndexUtils::ravelMultiIndex.
 *
 * @param indices     COO indices to be raveled
 * @param flatIndices output buffer for the raveled indices
 * @param length      number of entries to convert
 * @param shapeInfo   shape information of the full array
 * @param mode        clip mode, passed through as-is
 */
inline static void execRavelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
sd::sparse::IndexUtils::ravelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
}
/**
 * Converts flat indices back to COO multi-indices by forwarding all arguments,
 * unchanged, to sd::sparse::IndexUtils::unravelIndex.
 *
 * @param indices     output buffer for the unraveled COO indices
 * @param flatIndices flat indices to be unraveled
 * @param length      number of entries to convert
 * @param shapeInfo   shape information of the full array
 */
inline static void execUnravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
sd::sparse::IndexUtils::unravelIndex(indices, flatIndices, length, shapeInfo);
}
inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);

View File

@ -1482,6 +1482,30 @@ ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
Nd4jLong length,
const Nd4jLong *xShapeInfo);
/**
 * Converts COO multi-dimensional indices into flat (raveled) indices.
 *
 * @param extraPointers not used
 * @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
 * @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
 * @param length number of non-zero entries (length of flatIndices)
 * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
 * @param mode clipMode determines the strategy to use if some of the passed COO indices do
 * not fit into the shape determined by shapeInfo
 * 0 throw an exception (default)
 * 1 wrap around shape
 * 2 clip to shape
 */
ND4J_EXPORT void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
/**
 * Converts flat (raveled) indices back into COO multi-dimensional indices.
 *
 * @param extraPointers not used
 * @param indices DataBuffer where the unraveled COO indices are to be written
 * @param flatIndices DataBuffer containing the raveled/flattened indices to be unraveled
 * @param length number of non-zero entries (length of flatIndices)
 * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
 */
ND4J_EXPORT void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
// Exported declaration only; the behaviour depends on the backend implementation.
// NOTE(review): presumably memory-maps `length` bytes of `fileName` - confirm against the definition.
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);

View File

@ -1843,6 +1843,14 @@ void sortCooIndices(Nd4jPointer *extraPointers,
}
}
// Exported entry point: delegates to NativeOpExecutioner::execRavelMultiIndex.
// extraPointers is accepted for signature compatibility and is not forwarded.
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
NativeOpExecutioner::execRavelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
}
// Exported entry point: delegates to NativeOpExecutioner::execUnravelIndex.
// extraPointers is accepted for signature compatibility and is not forwarded.
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
NativeOpExecutioner::execUnravelIndex(indices, flatIndices, length, shapeInfo);
}
// Exported entry point: delegates to NativeOpExecutioner::encodeBitmap and returns its result.
// extraPointers is not forwarded.
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
}

View File

@ -2564,6 +2564,15 @@ void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values,
throw std::runtime_error("sortCooIndices:: Not implemented yet");
}
// Stub: ravelMultiIndex is not implemented in this backend.
// All arguments are ignored and std::runtime_error is always thrown.
// NOTE(review): this appears to be the CUDA backend - confirm.
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
throw std::runtime_error("ravelMultiIndex:: Not implemented yet");
}
// Stub: unravelIndex is not implemented in this backend.
// All arguments are ignored and std::runtime_error is always thrown.
// NOTE(review): this appears to be the CUDA backend - confirm.
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
throw std::runtime_error("unravelIndex:: Not implemented yet");
}
// Stub: ignores all arguments and always returns nullptr in this implementation.
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
return nullptr;
}

View File

@ -23,6 +23,8 @@
#include <system/pointercast.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <atomic>
#include <stdexcept>
#include <vector>
#include <helpers/shape.h>
#ifdef _OPENMP
#include <omp.h>
#endif
@ -219,5 +221,89 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
/**
 * Converts COO multi-indices into flat (raveled) offsets, modelled on numpy.ravel_multi_index.
 *
 * Entry i of the input occupies indices[i * rank .. i * rank + rank - 1]; its flat offset
 * sum_j(idx_j * stride_j) is written to flatIndices[i]. Rank, shape and strides are read
 * from shapeInfo.
 *
 * @param indices     COO indices, length * rank values
 * @param flatIndices output buffer, length values
 * @param length      number of entries to convert
 * @param shapeInfo   shape information of the full array
 * @param mode        handling of coordinates outside [0, shape[j]):
 *                    ND4J_CLIPMODE_CLIP  - clamp to 0 / shape[j] - 1
 *                    ND4J_CLIPMODE_WRAP  - wrap around modulo shape[j] (negative values wrap too)
 *                    ND4J_CLIPMODE_THROW or any other value - report an error
 * @throws std::runtime_error if at least one coordinate could not be mapped into the shape.
 *         The contents of flatIndices are unspecified in that case.
 */
void IndexUtils::ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode){
    const Nd4jLong * shape = shape::shapeOf(shapeInfo);
    const Nd4jLong * stride = shape::stride(shapeInfo);
    const Nd4jLong rank = shape::rank(shapeInfo);

    // Written concurrently by the loop iterations below, so it has to be atomic.
    std::atomic<bool> failed(false);

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong i = 0; i < length; ++i){
        Nd4jLong raveledIndex = 0;

        for (Nd4jLong j = 0; j < rank; ++j){
            Nd4jLong idx = indices[i * rank + j];
            const Nd4jLong dim = shape[j];

            if (idx < 0 || idx >= dim) {
                // index does not fit into shape at j dimension.
                // A zero-sized dimension has no valid index at all, so clip/wrap cannot repair it.
                if (dim > 0 && mode == ND4J_CLIPMODE_CLIP){
                    // clamp to the nearest valid coordinate
                    idx = idx < 0 ? 0 : dim - 1;
                }
                else if (dim > 0 && mode == ND4J_CLIPMODE_WRAP) {
                    idx %= dim;
                    // C++ remainder keeps the sign of the dividend; shift into [0, dim)
                    if (idx < 0)
                        idx += dim;
                } else {
                    // mode is ND4J_CLIPMODE_THROW or is unknown. either way, throw an error later.
                    // cannot throw here because of parallel region
                    nd4j_printf("sparse::IndexUtils::ravelMultiIndex Cannot ravel index at element %lld, does not fit into specified shape.\n", static_cast<long long>(i));
                    failed.store(true, std::memory_order_relaxed);
                }
            }

            raveledIndex += idx * stride[j];
        }

        flatIndices[i] = raveledIndex;
    }

    if (failed.load()){
        // throw error if one occurred in loop
        throw std::runtime_error("sparse::IndexUtils::ravelMultiIndex Cannot ravel index");
    }
}
/**
 * Converts flat (raveled) offsets back into COO multi-indices, modelled on numpy.unravel_index.
 *
 * For every entry i, flatIndices[i] is decomposed along the strides read from shapeInfo and the
 * resulting coordinates are written to indices[i * rank .. i * rank + rank - 1].
 *
 * @param indices     output buffer, length * rank values
 * @param flatIndices flat offsets to decompose, length values
 * @param length      number of entries to convert
 * @param shapeInfo   shape information of the full array
 * @throws std::runtime_error if a flat offset is negative or larger than the largest offset
 *         addressable with the given shape/strides. Coordinates of offending entries are not written.
 */
void IndexUtils::unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo){
    const Nd4jLong * shape = shape::shapeOf(shapeInfo);
    const Nd4jLong * stride = shape::stride(shapeInfo);
    const Nd4jLong rank = shape::rank(shapeInfo);

    // nothing to decompose (and no coordinates to write) for rank 0 or an empty input
    if (rank <= 0 || length <= 0)
        return;

    // unravelOrder ensures that the dimensions with largest stride are unraveled first.
    // A vector owns the storage, so it is released on the exception path as well.
    std::vector<int> unravelOrder(static_cast<size_t>(rank));
    for (int d = 0; d < static_cast<int>(rank); ++d)
        unravelOrder[d] = d;

    // Sort by descending stride. Dimensions of extent 1 can share a stride with a neighbour
    // (e.g. shape [5,1] -> strides [1,1]); the larger extent must then be handled first,
    // otherwise the size-1 dimension would swallow the whole quotient.
    std::sort(unravelOrder.begin(), unravelOrder.end(),
              [&](int i1, int i2) {
                  if (stride[i1] != stride[i2])
                      return stride[i1] > stride[i2];
                  return shape[i1] > shape[i2];
              });
    const int * order = unravelOrder.data();

    // largest raveled index that fits into the passed shape: every coordinate at its maximum.
    // Independent of the sort order; negative when a dimension is empty, which rejects everything.
    Nd4jLong maxRaveledIndex = 0;
    for (Nd4jLong j = 0; j < rank; ++j)
        maxRaveledIndex += (shape[j] - 1) * stride[j];

    // Written concurrently by the loop iterations below, so it has to be atomic.
    std::atomic<bool> failed(false);

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong i = 0; i < length; ++i){
        Nd4jLong raveledIndex = flatIndices[i];

        if (raveledIndex < 0 || raveledIndex > maxRaveledIndex){
            // cannot throw here because of parallel region
            nd4j_printf("sparse::IndexUtils::unravelIndex Cannot unravel index at element %lld. raveled index of %lld does not fit into specified shape.\n", static_cast<long long>(i), static_cast<long long>(raveledIndex));
            failed.store(true, std::memory_order_relaxed);
            continue;
        }

        for (Nd4jLong k = 0; k < rank; ++k){
            const int j = order[k];

            if (stride[j] == 0) {
                // a zero stride contributes nothing to the offset; avoid dividing by zero
                indices[i * rank + j] = 0;
                continue;
            }

            // how many strides of this size?
            indices[i * rank + j] = raveledIndex / stride[j];
            // remainder for subsequent smaller strides.
            raveledIndex %= stride[j];
        }
    }

    if (failed.load()){
        // throw error if one occurred in loop
        nd4j_printf("Largest raveled index is: %lld, ", static_cast<long long>(maxRaveledIndex));
        std::vector<Nd4jLong> v(shape, shape + rank);
        nd4j_printv("Shape: ", v);
        throw std::runtime_error("sparse::IndexUtils::unravelIndex Cannot unravel index");
    }
}
}
}

View File

@ -23,7 +23,13 @@
#ifndef LIBND4J_SPECIALS_SPARSE_H
#define LIBND4J_SPECIALS_SPARSE_H
// Clip modes accepted by sd::sparse::IndexUtils::ravelMultiIndex for indices that do not fit the shape.
// THROW: report an error (std::runtime_error).
#define ND4J_CLIPMODE_THROW 0
// WRAP: wrap the index around the dimension size (modulo).
#define ND4J_CLIPMODE_WRAP 1
// CLIP: clamp the index to the last valid position of the dimension.
#define ND4J_CLIPMODE_CLIP 2
#include <system/pointercast.h>
#include <system/dll.h>
namespace sd {
namespace sparse {
@ -61,6 +67,25 @@ namespace sd {
int rank);
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
};
/**
 * Static helpers converting between COO multi-indices and flat (raveled) indices.
 * Index buffers hold `rank` consecutive coordinates per entry; rank, shape and strides
 * are taken from the passed shapeInfo.
 */
class ND4J_EXPORT IndexUtils {
public:
/**
 * Converts indices in COO format into an array of flat indices
 *
 * based on numpy.ravel_multi_index
 *
 * @param indices     COO indices, length * rank values
 * @param flatIndices output buffer receiving one flat index per entry
 * @param length      number of entries
 * @param shapeInfo   shape information of the full array
 * @param mode        one of ND4J_CLIPMODE_THROW, ND4J_CLIPMODE_WRAP, ND4J_CLIPMODE_CLIP; decides what
 *                    happens to an index that does not fit the shape
 * @throws std::runtime_error when an index does not fit and mode is neither WRAP nor CLIP
 */
static void ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
/**
 * Converts flat indices to index matrix in COO format
 *
 * based on numpy.unravel_index
 *
 * @param indices     output buffer receiving length * rank coordinates
 * @param flatIndices flat indices to convert, length values
 * @param length      number of entries
 * @param shapeInfo   shape information of the full array
 * @throws std::runtime_error when a flat index does not fit into the shape
 */
static void unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
};
}
}

View File

@ -35,7 +35,8 @@ public:
//////////////////////////////////////////////////////////////////////
TEST_F(SparseUtilsTest, SortCOOindices_Test) {
#ifndef __CUDABLAS__
#ifndef __CUDABLAS__
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{
0,2,7,
@ -143,5 +144,102 @@ TEST_F(SparseUtilsTest, SortCOOindices_Test) {
delete[] indicesArr;
delete[] expIndicesArr;
#endif
}
#endif
}
//////////////////////////////////////////////////////////////////////
/**
 * Round-trip test for sd::sparse::IndexUtils:
 *  1. ravelMultiIndex maps the COO fixture onto the expected flat indices for shape {50, 60, 70};
 *  2. unravelIndex maps those flat indices back onto the original COO indices;
 *  3. ravelMultiIndex with ND4J_CLIPMODE_THROW throws once the shape is shrunk to {50, 60, 30}.
 *
 * nnz and rank come from the test fixture (declared outside this block); the tables below hold
 * 40 entries of 3 coordinates each. Skipped entirely when built with __CUDABLAS__.
 */
TEST_F(SparseUtilsTest, RavelIndices_Test) {
#ifndef __CUDABLAS__
    // input COO indices; also the expected result of the unravel round trip
    Nd4jLong * indicesArrExp = new Nd4jLong[nnz * rank]{
        0,2,7,
        2,36,35,
        3,30,17,
        5,12,22,
        5,43,45,
        6,32,11,
        8,8,32,
        9,29,11,
        5,11,22,
        15,26,16,
        17,48,49,
        24,28,31,
        26,6,23,
        31,21,31,
        35,46,45,
        37,13,14,
        6,38,18,
        7,28,20,
        8,29,39,
        8,32,30,
        9,42,43,
        11,15,18,
        13,18,45,
        29,26,39,
        30,8,25,
        42,31,24,
        28,33,5,
        31,27,1,
        35,43,26,
        36,8,37,
        39,22,14,
        39,24,42,
        42,48,2,
        43,26,48,
        44,23,49,
        45,18,34,
        46,28,5,
        46,32,17,
        48,34,44,
        49,38,39,
    };
    Nd4jLong * indicesArr = new Nd4jLong[nnz * rank];

    // expected flat indices for shape {50, 60, 70}
    Nd4jLong * flatIndicesExp = new Nd4jLong[nnz]{
        147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
        21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
        27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
        126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
        179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
    };
    Nd4jLong * flatIndices = new Nd4jLong[nnz];

    Nd4jLong * shape = new Nd4jLong[rank]{50, 60, 70};
    Nd4jLong * shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);

    // multi-index -> flat index
    sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
    for ( int i = 0; i < nnz; ++i){
        ASSERT_EQ(flatIndicesExp[i], flatIndices[i]);
    }

    // flat index -> multi-index (round trip)
    sd::sparse::IndexUtils::unravelIndex(indicesArr, flatIndices, nnz, shapeInfoBuffer);
    for ( int i = 0; i < nnz * rank; ++i){
        ASSERT_EQ(indicesArrExp[i], indicesArr[i]);
    }

    // shrink the last dimension so that some fixture indices no longer fit
    shape[2] = 30;
    // release the first shape info before the pointer is overwritten (it was leaked before)
    delete[] shapeInfoBuffer;
    shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);

    try {
        sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
        FAIL();
    } catch (const std::runtime_error& e) {
        // pass
    }

    delete[] indicesArrExp;
    delete[] indicesArr;
    delete[] flatIndicesExp;
    delete[] flatIndices;
    delete[] shape;
    delete[] shapeInfoBuffer;
#endif
}

View File

@ -1025,6 +1025,32 @@ public interface NativeOps {
// Native binding declaration; behaviour is defined by the backend implementation.
// NOTE(review): presumably sorts COO indices together with their values in x - confirm against the native side.
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
/**
 * Converts COO multi-dimensional indices into flat (raveled) indices.
 *
 * @param extraPointers not used
 * @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
 * @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
 * @param length number of non-zero entries (length of flatIndices)
 * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
 * @param mode clipMode determines the strategy to use if some of the passed COO indices do
 * not fit into the shape determined by shapeInfo
 * 0 throw an exception (default)
 * 1 wrap around shape
 * 2 clip to shape
 */
void ravelMultiIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo, int mode);
/**
 * Converts flat (raveled) indices back into COO multi-dimensional indices.
 *
 * @param extraPointers not used
 * @param indices DataBuffer where the unraveled COO indices are to be written
 * @param flatIndices DataBuffer containing the raveled/flattened indices to be unraveled
 * @param length number of non-zero entries (length of flatIndices)
 * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
 */
void unravelIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
// Native binding declaration. NOTE(review): presumably memory-maps the given file - confirm against the native side.
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
// Native binding declaration. NOTE(review): presumably releases a mapping obtained from mmapFile - confirm.
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);

View File

@ -3151,6 +3151,12 @@ public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPoin
@Cast("Nd4jLong") long length,
@Cast("const Nd4jLong*") long[] xShapeInfo);
// Native ravelMultiIndex bindings: the same call exposed for LongPointer, LongBuffer and long[] arguments.
// Parameters: extraPointers, COO indices, output flat indices, entry count, shape info, clip mode.
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo, int mode);
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo, int mode);
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo, int mode);
// Native unravelIndex bindings: the same call exposed for LongPointer, LongBuffer and long[] arguments.
// Parameters: extraPointers, output COO indices, input flat indices, entry count, shape info.
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo);
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo);
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo);
// Native mmapFile bindings: String and BytePointer file-name variants.
// NOTE(review): presumably memory-maps the given file - confirm against the native side.
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length);

View File

@ -0,0 +1,170 @@
package org.nd4j.linalg.specials;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.LongPointer;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.nativeblas.NativeOpsHolder;
/**
 * Tests the native ravelMultiIndex / unravelIndex operations through
 * NativeOpsHolder.getInstance().getDeviceNativeOps().
 *
 * Covers the ravel/unravel round trip for shape {50, 60, 70} and the three clip modes
 * of ravelMultiIndex (0 = throw, 1 = wrap, 2 = clip).
 */
@Slf4j
@RunWith(Parameterized.class)
public class RavelIndexTest extends BaseNd4jTest {
// global data type in effect when the test instance was created; restored after each test
DataType initialType;
public RavelIndexTest(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
}
// switches the global data type to FLOAT before each test
@Before
public void setUp() {
Nd4j.setDataType(DataType.FLOAT);
}
// restores the global data type captured in the constructor
@After
public void setDown() {
Nd4j.setDataType(initialType);
}
@Override
public char ordering() {
return 'c';
}
/**
 * Ravel/unravel round trip followed by one check per clip mode.
 * Returns early (no assertions) when the executioner class name contains "cuda".
 */
@Test
public void ravelIndexesTest() {
// FIXME: we don't want this test running on cuda for now
if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
return;
// 40 COO entries with 3 coordinates each
long[] multiIdxArray = new long[] {
0,2,7,
2,36,35,
3,30,17,
5,12,22,
5,43,45,
6,32,11,
8,8,32,
9,29,11,
5,11,22,
15,26,16,
17,48,49,
24,28,31,
26,6,23,
31,21,31,
35,46,45,
37,13,14,
6,38,18,
7,28,20,
8,29,39,
8,32,30,
9,42,43,
11,15,18,
13,18,45,
29,26,39,
30,8,25,
42,31,24,
28,33,5,
31,27,1,
35,43,26,
36,8,37,
39,22,14,
39,24,42,
42,48,2,
43,26,48,
44,23,49,
45,18,34,
46,28,5,
46,32,17,
48,34,44,
49,38,39,
};
// expected flat indices of the entries above for shape {50, 60, 70}
long[] flatIdxArray = new long[] {
147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
};
// 0 = throw an exception for indices that do not fit the shape
int clipMode = 0;
// number of coordinates per entry
long DIM = 3;
long length = multiIdxArray.length / DIM;
long[] shape = new long[] {50, 60, 70};
DataBuffer multiIdxDB = Nd4j.getDataBufferFactory().createLong(multiIdxArray);
DataBuffer flatIdxDB = Nd4j.getDataBufferFactory().createLong(flatIdxArray);
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
// output buffers for the two native calls
DataBuffer resultMulti = Nd4j.getDataBufferFactory().createLong(length*DIM);
DataBuffer resultFlat = Nd4j.getDataBufferFactory().createLong(length);
// multi-index -> flat index
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
Assert.assertArrayEquals(flatIdxArray, resultFlat.asLong());
// flat index -> multi-index
NativeOpsHolder.getInstance().getDeviceNativeOps().unravelIndex(null, (LongPointer) resultMulti.addressPointer(),
(LongPointer) flatIdxDB.addressPointer(), length, (LongPointer) shapeInfo.addressPointer());
Assert.assertArrayEquals(multiIdxArray, resultMulti.asLong());
//testing various clipMode cases
// clipMode = 0: throw an exception
try {
// shrink the last dimension so that some fixture indices no longer fit
shape[2] = 10;
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
Assert.fail("No exception thrown while using CLIP_MODE_THROW.");
} catch (RuntimeException e) {
//OK
}
// clipMode = 1: wrap around shape
// expected values assume c-order strides {6, 1} for shape {4, 6}:
// (3,4) -> 22, (6,5) -> (2,5) -> 17, (6,9) -> (2,3) -> 15
clipMode = 1;
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
length = 3;
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
Assert.assertArrayEquals(new long[] {22, 17, 15}, resultFlat.asLong());
// clipMode = 2: clip to shape
// (3,4) -> 22, (6,5) -> (3,5) -> 23, (6,9) -> (3,5) -> 23
clipMode = 2;
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
length = 3;
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
Assert.assertArrayEquals(new long[] {22, 23, 23}, resultFlat.asLong());
}
}