diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h
index 5e8f54550..886f5fa03 100644
--- a/libnd4j/include/legacy/NativeOpExecutioner.h
+++ b/libnd4j/include/legacy/NativeOpExecutioner.h
@@ -660,6 +660,13 @@ static void execTransformBool(sd::LaunchContext *lc,
     BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
   }
 
+  inline static void execRavelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
+    sd::sparse::IndexUtils::ravelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
+  }
+
+  inline static void execUnravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
+    sd::sparse::IndexUtils::unravelIndex(indices, flatIndices, length, shapeInfo);
+  }
 
   inline static Nd4jLong encodeBitmap(void *dx, const Nd4jLong *xShapeInfo, Nd4jLong N, int *dz, float threshold) {
     auto xType = sd::ArrayOptions::dataType(xShapeInfo);
diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h
index 392bf8e80..a55846d87 100755
--- a/libnd4j/include/legacy/NativeOps.h
+++ b/libnd4j/include/legacy/NativeOps.h
@@ -1482,6 +1482,30 @@ ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
                                 Nd4jLong length,
                                 const Nd4jLong *xShapeInfo);
 
+/**
+ *
+ * @param extraPointers not used
+ * @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
+ * @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
+ * @param length number of non-zero entries (length of flatIndices)
+ * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
+ * @param mode clipMode determines the strategy to use if some of the passed COO indices do
+ *             not fit into the shape described by shapeInfo:
+ *                 0   throw an exception (default)
+ *                 1   wrap around shape
+ *                 2   clip to shape
+ */
+ND4J_EXPORT void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
+
+/**
+ *
+ * @param extraPointers not used
+ * @param indices DataBuffer where the unraveled COO indices are to be written
+ * @param flatIndices DataBuffer containing the raveled/flattened indices to be unraveled
+ * @param length number of non-zero entries (length of flatIndices)
+ * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
+ */
+ND4J_EXPORT void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
 
 ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
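To make the three clip modes above concrete, here is a minimal standalone sketch (illustrative only, not part of the patch, and deliberately free of libnd4j headers) of how the out-of-range COO index {6, 5} is raveled for shape {4, 6} with c-order strides {6, 1}; the wrap and clip results match the values asserted in RavelIndexTest further down.

    // Illustrative only: mode 0 = throw (default), 1 = wrap, 2 = clip, mirroring the doc above.
    #include <cstdio>
    #include <stdexcept>

    static long ravel2d(long i0, long i1, int mode) {
        const long shape[2]  = {4, 6};
        const long stride[2] = {6, 1};            // c-order strides for shape {4, 6}
        long idx[2] = {i0, i1};
        for (int j = 0; j < 2; ++j) {
            if (idx[j] >= shape[j]) {
                if (mode == 2)      idx[j] = shape[j] - 1;                        // clip to shape
                else if (mode == 1) idx[j] %= shape[j];                           // wrap around shape
                else                throw std::runtime_error("index out of range"); // throw
            }
        }
        return idx[0] * stride[0] + idx[1] * stride[1];
    }

    int main() {
        std::printf("%ld %ld\n", ravel2d(6, 5, 1), ravel2d(6, 5, 2));  // prints "17 23"
        return 0;
    }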
diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp
index 44831b8ba..b483ef91a 100644
--- a/libnd4j/include/legacy/cpu/NativeOps.cpp
+++ b/libnd4j/include/legacy/cpu/NativeOps.cpp
@@ -1843,6 +1843,14 @@ void sortCooIndices(Nd4jPointer *extraPointers,
     }
 }
 
+void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
+    NativeOpExecutioner::execRavelMultiIndex(indices, flatIndices, length, shapeInfo, mode);
+}
+
+void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
+    NativeOpExecutioner::execUnravelIndex(indices, flatIndices, length, shapeInfo);
+}
+
 Nd4jLong encodeBitmap(Nd4jPointer *extraPointers, void *hX, Nd4jLong const* hXShapeInfo, Nd4jLong N, int *dz, float threshold) {
     return NativeOpExecutioner::encodeBitmap(hX, hXShapeInfo, N, dz, threshold);
 }
diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu
index d66c9827b..aa6c139d2 100755
--- a/libnd4j/include/legacy/cuda/NativeOps.cu
+++ b/libnd4j/include/legacy/cuda/NativeOps.cu
@@ -2564,6 +2564,15 @@ void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values,
     throw std::runtime_error("sortCooIndices:: Not implemented yet");
 }
 
+void ravelMultiIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
+    throw std::runtime_error("ravelMultiIndex:: Not implemented yet");
+}
+
+void unravelIndex(Nd4jPointer *extraPointers, Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
+    throw std::runtime_error("unravelIndex:: Not implemented yet");
+}
+
+
 Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length) {
     return nullptr;
 }
diff --git a/libnd4j/include/ops/impl/specials_sparse.cpp b/libnd4j/include/ops/impl/specials_sparse.cpp
index 1485557f8..ae9669f46 100644
--- a/libnd4j/include/ops/impl/specials_sparse.cpp
+++ b/libnd4j/include/ops/impl/specials_sparse.cpp
@@ -23,6 +23,8 @@
 #include
 #include
 #include
+#include
+
 #ifdef _OPENMP
 #include
 #endif
@@ -219,5 +221,89 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
     }
 
     BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT SparseUtils, , LIBND4J_TYPES);
+
+    void IndexUtils::ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode) {
+        Nd4jLong * shape = shape::shapeOf(shapeInfo);
+        Nd4jLong * stride = shape::stride(shapeInfo);
+        Nd4jLong rank = shape::rank(shapeInfo);
+        int errorCount = 0;
+
+        PRAGMA_OMP_PARALLEL_FOR
+        for (Nd4jLong i = 0; i < length; ++i) {
+            Nd4jLong raveledIndex = 0;
+            for (Nd4jLong j = 0; j < rank; ++j) {
+                Nd4jLong idx = indices[i * rank + j];
+                if (idx >= shape[j]) {
+                    // index does not fit into the shape at dimension j
+                    if (mode == ND4J_CLIPMODE_CLIP) {
+                        // set idx to the largest possible value (clip to shape)
+                        idx = shape[j] - 1;
+                    }
+                    else if (mode == ND4J_CLIPMODE_WRAP) {
+                        idx %= shape[j];
+                    } else {
+                        // mode is ND4J_CLIPMODE_THROW or unknown; either way an error has to be thrown later,
+                        // because we cannot throw inside the parallel region
+                        nd4j_printf("sparse::IndexUtils::ravelMultiIndex Cannot ravel index at element %d, does not fit into specified shape.\n", i);
+                        ++errorCount;
+                    }
+                }
+                raveledIndex += idx * stride[j];
+            }
+            flatIndices[i] = raveledIndex;
+        }
+
+        if (errorCount > 0) {
+            // throw the error if one occurred in the loop
+            throw std::runtime_error("sparse::IndexUtils::ravelMultiIndex Cannot ravel index");
+        }
+    }
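As a sanity check on the accumulation above (a standalone sketch, assuming the c-order strides that shape::shapeBuffer produces for the {50, 60, 70} shape used in the tests below): the raveled index is simply the dot product of the COO index with the strides, so the first test entry {0, 2, 7} maps to 0*4200 + 2*70 + 7*1 = 147, the first value in flatIndicesExp.

    #include <cstdio>

    int main() {
        const long shape[3]  = {50, 60, 70};
        const long stride[3] = {shape[1] * shape[2], shape[2], 1};  // c-order: {4200, 70, 1}
        const long idx[3]    = {0, 2, 7};                           // first COO entry in the test data
        long flat = 0;
        for (int j = 0; j < 3; ++j)
            flat += idx[j] * stride[j];                             // same accumulation as ravelMultiIndex
        std::printf("%ld\n", flat);                                 // prints "147"
        return 0;
    }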
+    void IndexUtils::unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo) {
+        Nd4jLong * shape = shape::shapeOf(shapeInfo);
+        Nd4jLong * stride = shape::stride(shapeInfo);
+        Nd4jLong rank = shape::rank(shapeInfo);
+        int errorCount = 0;
+
+        // unravelOrder ensures that the dimensions with the largest stride are unraveled first.
+        // create an array with the elements 0..rank
+        int * unravelOrder = shape::range(0, rank);
+
+        // sort the order according to stride length
+        std::sort(unravelOrder, unravelOrder + rank,
+                  [&](int i1, int i2) { return stride[i1] > stride[i2]; });
+
+        // calculate the largest raveled index that still fits into the passed shape
+        Nd4jLong maxRaveledIndex = shape[unravelOrder[0]] * stride[unravelOrder[0]] - 1;
+
+        PRAGMA_OMP_PARALLEL_FOR
+        for (Nd4jLong i = 0; i < length; ++i) {
+            Nd4jLong raveledIndex = flatIndices[i];
+            if (raveledIndex > maxRaveledIndex) {
+                // cannot throw here because of the parallel region
+                nd4j_printf("sparse::IndexUtils::unravelIndex Cannot unravel index at element %d. raveled index of %d does not fit into specified shape.\n", i, raveledIndex);
+                ++errorCount;
+            }
+
+            for (int * it = unravelOrder; it != unravelOrder + rank; it++) {
+                int j = *it;
+                // how many strides of this size fit?
+                indices[i * rank + j] = raveledIndex / stride[j];
+
+                // remainder for the subsequent smaller strides
+                raveledIndex %= stride[j];
+            }
+        }
+
+        if (errorCount > 0) {
+            // throw the error if one occurred in the loop
+            nd4j_printf("Largest raveled index is: %d, ", maxRaveledIndex);
+            std::vector<Nd4jLong> v(shape, shape + rank);
+            nd4j_printv("Shape: ", v);
+            throw std::runtime_error("sparse::IndexUtils::unravelIndex Cannot unravel index");
+        }
+
+        delete[] unravelOrder;
+    }
     }
 }
diff --git a/libnd4j/include/ops/specials_sparse.h b/libnd4j/include/ops/specials_sparse.h
index 4ea5a1573..35069048f 100644
--- a/libnd4j/include/ops/specials_sparse.h
+++ b/libnd4j/include/ops/specials_sparse.h
@@ -23,7 +23,13 @@
 #ifndef LIBND4J_SPECIALS_SPARSE_H
 #define LIBND4J_SPECIALS_SPARSE_H
 
+#define ND4J_CLIPMODE_THROW 0
+#define ND4J_CLIPMODE_WRAP 1
+#define ND4J_CLIPMODE_CLIP 2
+
 #include
+#include
+
 namespace sd {
     namespace sparse {
 
@@ -61,6 +67,25 @@
                                           int rank);
 
             static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
+
+
+        };
+
+        class ND4J_EXPORT IndexUtils {
+        public:
+            /**
+             * Converts indices in COO format into an array of flat indices
+             *
+             * based on numpy.ravel_multi_index
+             */
+            static void ravelMultiIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo, int mode);
+
+            /**
+             * Converts flat indices to an index matrix in COO format
+             *
+             * based on numpy.unravel_index
+             */
+            static void unravelIndex(Nd4jLong *indices, Nd4jLong *flatIndices, Nd4jLong length, Nd4jLong *shapeInfo);
         };
     }
 }
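The reverse direction used by the unravelIndex implementation above can be checked the same way (a standalone sketch with the same assumed c-order strides): dividing by the strides in descending-stride order, and keeping the remainder for the smaller strides, turns 147 back into {0, 2, 7}. This is exactly the round trip asserted in the tests below.

    #include <cstdio>

    int main() {
        const long stride[3] = {4200, 70, 1};   // already in descending order, as unravelOrder arranges
        long flat = 147;
        long idx[3];
        for (int j = 0; j < 3; ++j) {
            idx[j] = flat / stride[j];          // how many strides of this size fit
            flat  %= stride[j];                 // remainder for the subsequent smaller strides
        }
        std::printf("%ld %ld %ld\n", idx[0], idx[1], idx[2]);  // prints "0 2 7"
        return 0;
    }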
diff --git a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp
index 37f52568f..c6b2c0a54 100644
--- a/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp
+++ b/libnd4j/tests_cpu/layers_tests/SparseUtilsTest.cpp
@@ -35,7 +35,8 @@ public:
 
 //////////////////////////////////////////////////////////////////////
 TEST_F(SparseUtilsTest, SortCOOindices_Test) {
-    #ifndef __CUDABLAS__
+#ifndef __CUDABLAS__
+
     Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{
             0,2,7,
@@ -143,5 +144,102 @@ TEST_F(SparseUtilsTest, SortCOOindices_Test) {
 
     delete[] indicesArr;
     delete[] expIndicesArr;
-    #endif
-}
\ No newline at end of file
+
+#endif
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(SparseUtilsTest, RavelIndices_Test) {
+
+#ifndef __CUDABLAS__
+
+    Nd4jLong * indicesArrExp = new Nd4jLong[nnz * rank]{
+            0,2,7,
+            2,36,35,
+            3,30,17,
+            5,12,22,
+            5,43,45,
+            6,32,11,
+            8,8,32,
+            9,29,11,
+            5,11,22,
+            15,26,16,
+            17,48,49,
+            24,28,31,
+            26,6,23,
+            31,21,31,
+            35,46,45,
+            37,13,14,
+            6,38,18,
+            7,28,20,
+            8,29,39,
+            8,32,30,
+            9,42,43,
+            11,15,18,
+            13,18,45,
+            29,26,39,
+            30,8,25,
+            42,31,24,
+            28,33,5,
+            31,27,1,
+            35,43,26,
+            36,8,37,
+            39,22,14,
+            39,24,42,
+            42,48,2,
+            43,26,48,
+            44,23,49,
+            45,18,34,
+            46,28,5,
+            46,32,17,
+            48,34,44,
+            49,38,39,
+    };
+    Nd4jLong * indicesArr = new Nd4jLong[nnz * rank];
+
+    Nd4jLong * flatIndicesExp = new Nd4jLong[nnz]{
+            147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
+            21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
+            27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
+            126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
+            179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
+    };
+
+    Nd4jLong * flatIndices = new Nd4jLong[nnz];
+
+
+    Nd4jLong * shape = new Nd4jLong[rank]{50, 60, 70};
+    Nd4jLong * shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
+
+
+    sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
+
+    for (int i = 0; i < nnz; ++i) {
+        ASSERT_EQ(flatIndicesExp[i], flatIndices[i]);
+    }
+
+    sd::sparse::IndexUtils::unravelIndex(indicesArr, flatIndices, nnz, shapeInfoBuffer);
+
+    for (int i = 0; i < nnz * rank; ++i) {
+        ASSERT_EQ(indicesArrExp[i], indicesArr[i]);
+    }
+
+    shape[2] = 30;
+    shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape);
+
+    try {
+        sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW);
+        FAIL();
+    } catch (const std::runtime_error& e) {
+        // pass
+    }
+
+    delete[] indicesArrExp;
+    delete[] indicesArr;
+    delete[] flatIndicesExp;
+    delete[] flatIndices;
+    delete[] shape;
+    delete[] shapeInfoBuffer;
+
+#endif
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
index ecae7e272..0861c4e1b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java
@@ -1025,6 +1025,32 @@ public interface NativeOps {
     void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length,
                         @Cast("Nd4jLong *") LongPointer shapeInfo);
 
+    /**
+     *
+     * @param extraPointers not used
+     * @param indices DataBuffer containing COO indices for a sparse matrix that is to be raveled/flattened
+     * @param flatIndices DataBuffer where the raveled/flattened indices are to be written to
+     * @param length number of non-zero entries (length of flatIndices)
+     * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be flattened
+     * @param mode clipMode determines the strategy to use if some of the passed COO indices do
+     *             not fit into the shape described by shapeInfo:
+     *                 0   throw an exception (default)
+     *                 1   wrap around shape
+     *                 2   clip to shape
+     */
+    void ravelMultiIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo, int mode);
+
+    /**
+     *
+     * @param extraPointers not used
+     * @param indices DataBuffer where the unraveled COO indices are to be written
+     * @param flatIndices DataBuffer containing the raveled/flattened indices to be unraveled
+     * @param length number of non-zero entries (length of flatIndices)
+     * @param shapeInfo DataBuffer with ShapeInfo for the full matrix to be unraveled
+     */
+    void unravelIndex(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, @Cast("Nd4jLong *") LongPointer flatIndices, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
+
+
     LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
 
     void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index bca2a4dbf..93ddc153b 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -3151,6 +3151,12 @@ public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPoin
                                   @Cast("Nd4jLong") long length, @Cast("const Nd4jLong*") long[] xShapeInfo);
 
+public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo, int mode);
+public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo, int mode);
+public native void ravelMultiIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo, int mode);
+public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, @Cast("Nd4jLong*") LongPointer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongPointer shapeInfo);
+public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong*") LongBuffer flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") LongBuffer shapeInfo);
+public native void unravelIndex(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, @Cast("Nd4jLong*") long[] flatIndices, @Cast("Nd4jLong") long length, @Cast("Nd4jLong*") long[] shapeInfo);
 
 public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
 public native @Cast("Nd4jLong*") LongBuffer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer fileName, @Cast("Nd4jLong") long length);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java
new file mode 100644
index 000000000..f9ea69c08
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/RavelIndexTest.java
@@ -0,0 +1,170 @@
+package org.nd4j.linalg.specials;
+
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.javacpp.LongPointer;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.nativeblas.NativeOpsHolder;
+
+
+@Slf4j
+@RunWith(Parameterized.class)
+public class RavelIndexTest extends BaseNd4jTest {
+
+    DataType initialType;
+
+    public RavelIndexTest(Nd4jBackend backend) {
+        super(backend);
+        this.initialType = Nd4j.dataType();
+    }
+
+    @Before
+    public void setUp() {
+        Nd4j.setDataType(DataType.FLOAT);
+    }
+
+    @After
+    public void setDown() {
+        Nd4j.setDataType(initialType);
+    }
+
+    @Override
+    public char ordering() {
+        return 'c';
+    }
+
+
+    @Test
+    public void ravelIndexesTest() {
+        // FIXME: we don't want this test running on cuda for now
+        if (Nd4j.getExecutioner().getClass().getCanonicalName().toLowerCase().contains("cuda"))
+            return;
+
+        long[] multiIdxArray = new long[] {
+                0,2,7,
+                2,36,35,
+                3,30,17,
+                5,12,22,
+                5,43,45,
+                6,32,11,
+                8,8,32,
+                9,29,11,
+                5,11,22,
+                15,26,16,
+                17,48,49,
+                24,28,31,
+                26,6,23,
+                31,21,31,
+                35,46,45,
+                37,13,14,
+                6,38,18,
+                7,28,20,
+                8,29,39,
+                8,32,30,
+                9,42,43,
+                11,15,18,
+                13,18,45,
+                29,26,39,
+                30,8,25,
+                42,31,24,
+                28,33,5,
+                31,27,1,
+                35,43,26,
+                36,8,37,
+                39,22,14,
+                39,24,42,
+                42,48,2,
+                43,26,48,
+                44,23,49,
+                45,18,34,
+                46,28,5,
+                46,32,17,
+                48,34,44,
+                49,38,39,
+        };
+
+        long[] flatIdxArray = new long[] {
+                147, 10955, 14717, 21862, 24055, 27451, 34192, 39841,
+                21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324,
+                27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659,
+                126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522,
+                179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499
+        };
+
+
+        int clipMode = 0;
+
+        long DIM = 3;
+        long length = multiIdxArray.length / DIM;
+        long[] shape = new long[] {50, 60, 70};
+
+
+        DataBuffer multiIdxDB = Nd4j.getDataBufferFactory().createLong(multiIdxArray);
+        DataBuffer flatIdxDB = Nd4j.getDataBufferFactory().createLong(flatIdxArray);
+        DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
+
+        DataBuffer resultMulti = Nd4j.getDataBufferFactory().createLong(length * DIM);
+        DataBuffer resultFlat = Nd4j.getDataBufferFactory().createLong(length);
+
+        NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
+                (LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
+
+        Assert.assertArrayEquals(flatIdxArray, resultFlat.asLong());
+
+        NativeOpsHolder.getInstance().getDeviceNativeOps().unravelIndex(null, (LongPointer) resultMulti.addressPointer(),
+                (LongPointer) flatIdxDB.addressPointer(), length, (LongPointer) shapeInfo.addressPointer());
+
+        Assert.assertArrayEquals(multiIdxArray, resultMulti.asLong());
+
+
+        // testing the various clipMode cases
+        // clipMode = 0: throw an exception
+        try {
+            shape[2] = 10;
+            shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(shape, DataType.FLOAT).getFirst();
+            NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
+                    (LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
+            Assert.fail("No exception thrown while using CLIP_MODE_THROW.");
+
+        } catch (RuntimeException e) {
+            // OK
+        }
+
+        // clipMode = 1: wrap around shape
+        clipMode = 1;
+        multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
+        resultFlat = Nd4j.getDataBufferFactory().createLong(3);
+        shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
+        length = 3;
+
+        NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
+                (LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
+
+        Assert.assertArrayEquals(new long[] {22, 17, 15}, resultFlat.asLong());
+
+        // clipMode = 2: clip to shape
+        clipMode = 2;
+        multiIdxDB = Nd4j.getDataBufferFactory().createLong(new long[] {3,4, 6,5, 6,9});
+        resultFlat = Nd4j.getDataBufferFactory().createLong(3);
+        shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {4, 6}, DataType.FLOAT).getFirst();
+        length = 3;
+
+        NativeOpsHolder.getInstance().getDeviceNativeOps().ravelMultiIndex(null, (LongPointer) multiIdxDB.addressPointer(),
+                (LongPointer) resultFlat.addressPointer(), length, (LongPointer) shapeInfo.addressPointer(), clipMode);
+
+        Assert.assertArrayEquals(new long[] {22, 23, 23}, resultFlat.asLong());
+
+    }
+
+}