add IndexUtils class containing ravelMultiIndex and unravelIndex methods (#9122)
Also add test functions both for Java and C++. Signed-off-by: Péter Zarándy <pza@wehowsky.com>master
parent
a1fcc5f19f
commit
95ca39bd21
|
@ -660,6 +660,13 @@ static void execTransformBool(sd::LaunchContext *lc,
|
|||
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
inline static void execRavelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||
sd::sparse::IndexUtils::ravelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
|
||||
}
|
||||
|
||||
inline static void execUnravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||
sd::sparse::IndexUtils::unravelIndex(indices, flatIndices, length, shapeInfo);
|
||||
}
|
||||
|
||||
inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
|
|
|
@ -1482,6 +1482,30 @@ ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
|
|||
Nd4jLong length,
|
||||
const Nd4jLong *xShapeInfo);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param extraPointers not used
|
||||
* @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
|
||||
* @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
|
||||
* @param length number of non-zero entries (length of flatIndices)
|
||||
* @param fullShapeBuffer DataBuffer with ShapeInfo for the full matrix to be flattened
|
||||
* @param mode clipMode determines the strategy to use if some of the the passed COO indices does
|
||||
* not fit into the shape determined by fullShapeBuffer
|
||||
* 0 throw an exception (default)
|
||||
* 1 wrap around shape
|
||||
* 2 clip to shape
|
||||
*/
|
||||
ND4J_EXPORT void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param extraPointers not used
|
||||
* @param indices DataBuffer where the unraveled COO indices are to be written
|
||||
* @param flatIndices DataBuffer containing the raveled/flattened indices to be unravel
|
||||
* @param length number of non-zero entries (length of flatIndices)
|
||||
* @param fullShapeBuffer DataBuffer with ShapeInfo for the full matrix to be unraveled
|
||||
*/
|
||||
ND4J_EXPORT void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
|
||||
|
||||
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
||||
|
||||
|
|
|
@ -1843,6 +1843,14 @@ void sortCooIndices(Nd4jPointer *extraPointers,
|
|||
}
|
||||
}
|
||||
|
||||
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||
NativeOpExecutioner::execRavelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
|
||||
}
|
||||
|
||||
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||
NativeOpExecutioner::execUnravelIndex(indices, flatIndices, length, shapeInfo);
|
||||
}
|
||||
|
||||
Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
|
||||
return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
|
||||
}
|
||||
|
|
|
@ -2564,6 +2564,15 @@ void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values,
|
|||
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
||||
}
|
||||
|
||||
void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
|
||||
throw std::runtime_error("ravelMultiIndex:: Not implemented yet");
|
||||
}
|
||||
|
||||
void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
|
||||
throw std::runtime_error("unravelIndex:: Not implemented yet");
|
||||
}
|
||||
|
||||
|
||||
Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include <system/pointercast.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <helpers/shape.h>
|
||||
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
@ -219,5 +221,89 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
|
|||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
|
||||
|
||||
void IndexUtils::ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode){
|
||||
Nd4jLong * shape = shape::shapeOf(shapeInfo);
|
||||
Nd4jLong * stride = shape::stride(shapeInfo);
|
||||
Nd4jLong rank = shape::rank(shapeInfo);
|
||||
int errorCount = 0;
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong i = 0; i < length; ++i){
|
||||
Nd4jLong raveledIndex = 0;
|
||||
for (Nd4jLong j = 0; j < rank; ++j){
|
||||
Nd4jLong idx = indices[i * rank + j];
|
||||
if (idx >= shape[j]) {
|
||||
// index does not fit into shape at j dimension.
|
||||
if (mode == ND4J_CLIPMODE_CLIP){
|
||||
// set idx to largest possible value (clip to shape)
|
||||
idx = shape[j] - 1;
|
||||
}
|
||||
else if (mode == ND4J_CLIPMODE_WRAP) {
|
||||
idx %= shape[j];
|
||||
} else {
|
||||
// mode is ND4J_CLIPMODE_THROW or is unknown. either way. throw an error later.
|
||||
// cannot throw here because of parallel region
|
||||
nd4j_printf("sparse::IndexUtils::ravelMultiIndex Cannot ravel index at element %d, does not fit into specified shape.\n", i);
|
||||
++errorCount;
|
||||
}
|
||||
}
|
||||
raveledIndex += idx * stride[j];
|
||||
}
|
||||
flatIndices[i] = raveledIndex;
|
||||
}
|
||||
|
||||
if (errorCount > 0){
|
||||
// throw error if one ocurred in loop
|
||||
throw std::runtime_error("sparse::IndexUtils::ravelMultiIndex Cannot ravel index");
|
||||
}
|
||||
}
|
||||
|
||||
void IndexUtils::unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo){
|
||||
Nd4jLong * shape = shape::shapeOf(shapeInfo);
|
||||
Nd4jLong * stride = shape::stride(shapeInfo);
|
||||
Nd4jLong rank = shape::rank(shapeInfo);
|
||||
int errorCount = 0;
|
||||
|
||||
// unravelOrder ensures that the dimensions with largest stride are unraveled first.
|
||||
// create vector with elements 0..rank
|
||||
int * unravelOrder = shape::range<int>(0, rank);
|
||||
|
||||
// sort order according to stride length.
|
||||
std::sort(unravelOrder, unravelOrder + rank,
|
||||
[&](int i1, int i2) { return stride[i1] > stride[i2]; } );
|
||||
|
||||
// calculate the largest raveled index that will fit into passed shape
|
||||
Nd4jLong maxRaveledIndex = shape[unravelOrder[0]] * stride[unravelOrder[0]] - 1;
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (Nd4jLong i = 0; i < length; ++i){
|
||||
Nd4jLong raveledIndex = flatIndices[i];
|
||||
if (raveledIndex > maxRaveledIndex){
|
||||
// cannot throw here because of parallel region
|
||||
nd4j_printf("sparse::IndexUtils::unravelIndex Cannot unravel index at element %d. raveled index of %d does not fit into specified shape.\n", i, raveledIndex);
|
||||
++errorCount;
|
||||
}
|
||||
|
||||
for (int * it = unravelOrder; it != unravelOrder + rank; it++){
|
||||
int j = *it;
|
||||
// how many strides of this size?
|
||||
indices[i * rank + j] = raveledIndex / stride[j];
|
||||
|
||||
// remainder for subsequent smaller strides.
|
||||
raveledIndex %= stride[j];
|
||||
}
|
||||
}
|
||||
|
||||
if (errorCount > 0){
|
||||
// throw error if one ocurred in loop
|
||||
nd4j_printf("Largest raveled index is: %d, ", maxRaveledIndex)
|
||||
std::vector<Nd4jLong> v(shape, shape + rank);
|
||||
nd4j_printv("Shape: ", v);
|
||||
throw std::runtime_error("sparse::IndexUtils::unravelIndex Cannot unravel index");
|
||||
}
|
||||
|
||||
delete[] unravelOrder;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,13 @@
|
|||
#ifndef LIBND4J_SPECIALS_SPARSE_H
|
||||
#define LIBND4J_SPECIALS_SPARSE_H
|
||||
|
||||
#define ND4J_CLIPMODE_THROW 0
|
||||
#define ND4J_CLIPMODE_WRAP 1
|
||||
#define ND4J_CLIPMODE_CLIP 2
|
||||
|
||||
#include <system/pointercast.h>
|
||||
#include <system/dll.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace sparse {
|
||||
|
@ -61,6 +67,25 @@ namespace sd {
|
|||
int rank);
|
||||
|
||||
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
|
||||
|
||||
|
||||
};
|
||||
|
||||
class ND4J_EXPORT IndexUtils {
|
||||
public:
|
||||
/**
|
||||
* Converts indices in COO format into an array of flat indices
|
||||
*
|
||||
* based on numpy.ravel_multi_index
|
||||
*/
|
||||
static void ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
|
||||
|
||||
/**
|
||||
* Converts flat indices to index matrix in COO format
|
||||
*
|
||||
* based on numpy.unravel_index
|
||||
*/
|
||||
static void unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,7 +35,8 @@ public:
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(SparseUtilsTest, SortCOOindices_Test) {
|
||||
|
||||
#ifndef __CUDABLAS__
|
||||
#ifndef __CUDABLAS__
|
||||
|
||||
|
||||
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{
|
||||
0,2,7,
|
||||
|
@ -143,5 +144,102 @@ TEST_F(SparseUtilsTest, SortCOOindices_Test) {
|
|||
delete[] indicesArr;
|
||||
delete[] expIndicesArr;
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(SparseUtilsTest, RavelIndices_Test) {
|
||||
|
||||
#ifndef __CUDABLAS__
|
||||
|
||||
Nd4jLong * indicesArrExp = new Nd4jLong[nnz * rank]{
|
||||
0,2,7,
|
||||
2,36,35,
|
||||
3,30,17,
|
||||
5,12,22,
|
||||
5,43,45,
|
||||
6,32,11,
|
||||
8,8,32,
|
||||
9,29,11,
|
||||
5,11,22,
|
||||
15,26,16,
|
||||
17,48,49,
|
||||
24,28,31,
|
||||
26,6,23,
|
||||
31,21,31,
|
||||
35,46,45,
|
||||
37,13,14,
|
||||
6,38,18,
|
||||
7,28,20,
|
||||
8,29,39,
|
||||
8,32,30,
|
||||
9,42,43,
|
||||
11,15,18,
|
||||
13,18,45,
|
||||
29,26,39,
|
||||
30,8,25,
|
||||
42,31,24,
|
||||
28,33,5,
|
||||
31,27,1,
|
||||
35,43,26,
|
||||
36,8,37,
|
||||
39,22,14,
|
||||
39,24,42,
|
||||
42,48,2,
|
||||
43,26,48,
|
||||
44,23,49,
|
||||
45,18,34,
|
||||
46,28,5,
|
||||
46,32,17,
|
||||
48,34,44,
|
||||
49,38,39,
|
||||
};
|
||||
Nd4jLong * indicesArr = new Nd4jLong[nnz * rank];
|
||||
|
||||
Nd4jLong * flatIndicesExp = new Nd4jLong[nnz]{
|
||||
147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
|
||||
21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
|
||||
27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
|
||||
126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
|
||||
179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
|
||||
};
|
||||
|
||||
Nd4jLong * flatIndices = new Nd4jLong[nnz];
|
||||
|
||||
|
||||
Nd4jLong * shape = new Nd4jLong[rank]{50, 60, 70};
|
||||
Nd4jLong * shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
||||
|
||||
|
||||
sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
|
||||
|
||||
for ( int i = 0; i < nnz; ++i){
|
||||
ASSERT_EQ(flatIndicesExp[i], flatIndices[i]);
|
||||
}
|
||||
|
||||
sd::sparse::IndexUtils::unravelIndex(indicesArr, flatIndices, nnz, shapeInfoBuffer);
|
||||
|
||||
for ( int i = 0; i < nnz * rank; ++i){
|
||||
ASSERT_EQ(indicesArrExp[i], indicesArr[i]);
|
||||
}
|
||||
|
||||
shape[2] = 30;
|
||||
shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
|
||||
|
||||
try {
|
||||
sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
|
||||
FAIL();
|
||||
} catch (const std::runtime_error& e) {
|
||||
// pass
|
||||
}
|
||||
|
||||
delete[] indicesArrExp;
|
||||
delete[] indicesArr;
|
||||
delete[] flatIndicesExp;
|
||||
delete[] flatIndices;
|
||||
delete[] shape;
|
||||
delete[] shapeInfoBuffer;
|
||||
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -1025,6 +1025,32 @@ public interface NativeOps {
|
|||
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param extraPointers not used
|
||||
* @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
|
||||
* @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
|
||||
* @param length number of non-zero entries (length of flatIndices)
|
||||
* @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
|
||||
* @param mode clipMode determines the strategy to use if some of the the passed COO indices does
|
||||
* not fit into the shape determined by fullShapeBuffer
|
||||
* 0 throw an exception (default)
|
||||
* 1 wrap around shape
|
||||
* 2 clip to shape
|
||||
*/
|
||||
void ravelMultiIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo, int mode);
|
||||
|
||||
/**
|
||||
*
|
||||
* @param extraPointers not used
|
||||
* @param indices DataBuffer where the unraveled COO indices are to be written
|
||||
* @param flatIndices DataBuffer containing the raveled/flattened indices to be unravel
|
||||
* @param length number of non-zero entries (length of flatIndices)
|
||||
* @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
|
||||
*/
|
||||
void unravelIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||
|
||||
|
||||
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
||||
|
||||
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
||||
|
|
|
@ -3151,6 +3151,12 @@ public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPoin
|
|||
@Cast("Nd4jLong") long length,
|
||||
@Cast("const Nd4jLong*") long[] xShapeInfo);
|
||||
|
||||
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo, int mode);
|
||||
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo, int mode);
|
||||
public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo, int mode);
|
||||
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||
public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo);
|
||||
|
||||
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
||||
public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length);
|
||||
|
|
|
@ -0,0 +1,170 @@
|
|||
package org.nd4j.linalg.specials;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@RunWith(Parameterized.class)
|
||||
public class RavelIndexTest extends BaseNd4jTest {
|
||||
|
||||
DataType initialType;
|
||||
|
||||
public RavelIndexTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
}
|
||||
|
||||
@After
|
||||
public void setDown() {
|
||||
Nd4j.setDataType(initialType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void ravelIndexesTest() {
|
||||
// FIXME: we don't want this test running on cuda for now
|
||||
if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
|
||||
return;
|
||||
|
||||
long[] multiIdxArray = new long[] {
|
||||
0,2,7,
|
||||
2,36,35,
|
||||
3,30,17,
|
||||
5,12,22,
|
||||
5,43,45,
|
||||
6,32,11,
|
||||
8,8,32,
|
||||
9,29,11,
|
||||
5,11,22,
|
||||
15,26,16,
|
||||
17,48,49,
|
||||
24,28,31,
|
||||
26,6,23,
|
||||
31,21,31,
|
||||
35,46,45,
|
||||
37,13,14,
|
||||
6,38,18,
|
||||
7,28,20,
|
||||
8,29,39,
|
||||
8,32,30,
|
||||
9,42,43,
|
||||
11,15,18,
|
||||
13,18,45,
|
||||
29,26,39,
|
||||
30,8,25,
|
||||
42,31,24,
|
||||
28,33,5,
|
||||
31,27,1,
|
||||
35,43,26,
|
||||
36,8,37,
|
||||
39,22,14,
|
||||
39,24,42,
|
||||
42,48,2,
|
||||
43,26,48,
|
||||
44,23,49,
|
||||
45,18,34,
|
||||
46,28,5,
|
||||
46,32,17,
|
||||
48,34,44,
|
||||
49,38,39,
|
||||
};
|
||||
|
||||
long[] flatIdxArray = new long[] {
|
||||
147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
|
||||
21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
|
||||
27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
|
||||
126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
|
||||
179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
|
||||
};
|
||||
|
||||
|
||||
|
||||
int clipMode = 0;
|
||||
|
||||
|
||||
long DIM = 3;
|
||||
long length = multiIdxArray.length / DIM;
|
||||
long[] shape = new long[] {50, 60, 70};
|
||||
|
||||
|
||||
DataBuffer multiIdxDB = Nd4j.getDataBufferFactory().createLong(multiIdxArray);
|
||||
DataBuffer flatIdxDB = Nd4j.getDataBufferFactory().createLong(flatIdxArray);
|
||||
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
|
||||
|
||||
DataBuffer resultMulti = Nd4j.getDataBufferFactory().createLong(length*DIM);
|
||||
DataBuffer resultFlat = Nd4j.getDataBufferFactory().createLong(length);
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
|
||||
|
||||
Assert.assertArrayEquals(flatIdxArray, resultFlat.asLong());
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().unravelIndex(null, (LongPointer) resultMulti.addressPointer(),
|
||||
(LongPointer) flatIdxDB.addressPointer(), length, (LongPointer) shapeInfo.addressPointer());
|
||||
|
||||
Assert.assertArrayEquals(multiIdxArray, resultMulti.asLong());
|
||||
|
||||
|
||||
//testing various clipMode cases
|
||||
// clipMode = 0: throw an exception
|
||||
try {
|
||||
shape[2] = 10;
|
||||
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(),clipMode);
|
||||
Assert.fail("No exception thrown while using CLIP_MODE_THROW.");
|
||||
|
||||
} catch (RuntimeException e) {
|
||||
//OK
|
||||
}
|
||||
// clipMode = 1: wrap around shape
|
||||
clipMode = 1;
|
||||
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
|
||||
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
|
||||
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
|
||||
length = 3;
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
|
||||
Assert.assertArrayEquals(new long[] {22, 17, 15}, resultFlat.asLong());
|
||||
|
||||
// clipMode = 2: clip to shape
|
||||
clipMode = 2;
|
||||
multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
|
||||
resultFlat = Nd4j.getDataBufferFactory().createLong(3);
|
||||
shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
|
||||
length = 3;
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
|
||||
(LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
|
||||
|
||||
Assert.assertArrayEquals(new long[] {22, 23, 23}, resultFlat.asLong());
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue