diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h index 20f0b6532..5e8f54550 100644 --- a/libnd4j/include/legacy/NativeOpExecutioner.h +++ b/libnd4j/include/legacy/NativeOpExecutioner.h @@ -28,6 +28,7 @@ #include #include #include +#include /** * Native op executioner: @@ -652,8 +653,11 @@ static void execTransformBool(sd::LaunchContext *lc, BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES); } - inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) { - sd::sparse::SparseUtils::sortCooIndicesGeneric(indices, reinterpret_cast(values), length, rank); + inline static void execSortCooIndices(Nd4jLong *indices, void *x, Nd4jLong length, const Nd4jLong *xShapeInfo) { + auto xType = sd::ArrayOptions::dataType(xShapeInfo); + int rank = shape::rank(xShapeInfo); + + BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES); } diff --git a/libnd4j/include/legacy/NativeOps.h b/libnd4j/include/legacy/NativeOps.h index 66c8751da..392bf8e80 100755 --- a/libnd4j/include/legacy/NativeOps.h +++ b/libnd4j/include/legacy/NativeOps.h @@ -1476,7 +1476,11 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers, // special sort impl for sorting out COO indices and values -ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank); +ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, + Nd4jLong *indices, + void *x, + Nd4jLong length, + const Nd4jLong *xShapeInfo); ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length); diff --git a/libnd4j/include/legacy/cpu/NativeOps.cpp b/libnd4j/include/legacy/cpu/NativeOps.cpp index 463adc17e..44831b8ba 100644 --- a/libnd4j/include/legacy/cpu/NativeOps.cpp +++ b/libnd4j/include/legacy/cpu/NativeOps.cpp @@ -1832,11 +1832,11 @@ void sortTad(Nd4jPointer *extraPointers, void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, - void *values, + void *x, Nd4jLong length, - int rank) { + const Nd4jLong *xShapeInfo) { try { - NativeOpExecutioner::execSortCooIndices(indices, values, length, rank); + NativeOpExecutioner::execSortCooIndices(indices, x, length, xShapeInfo); } catch (std::exception &e) { sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1); sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what()); diff --git a/libnd4j/include/legacy/cuda/NativeOps.cu b/libnd4j/include/legacy/cuda/NativeOps.cu index 186b3a9cb..d66c9827b 100755 --- a/libnd4j/include/legacy/cuda/NativeOps.cu +++ b/libnd4j/include/legacy/cuda/NativeOps.cu @@ -2560,7 +2560,7 @@ void sortTad(Nd4jPointer *extraPointers, } } -void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) { +void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, const Nd4jLong *xShapeInfo) { throw std::runtime_error("sortCooIndices:: Not implemented yet"); } diff --git a/libnd4j/include/ops/impl/specials_sparse.cpp b/libnd4j/include/ops/impl/specials_sparse.cpp index c782ccf18..1485557f8 100644 --- a/libnd4j/include/ops/impl/specials_sparse.cpp +++ b/libnd4j/include/ops/impl/specials_sparse.cpp @@ -209,7 +209,8 @@ PRAGMA_OMP_SINGLE_ARGS(nowait) } template - void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) { + void SparseUtils::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) { + auto values = reinterpret_cast(vx); #ifdef _OPENMP coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank); #else diff --git a/libnd4j/include/ops/specials_sparse.h b/libnd4j/include/ops/specials_sparse.h index cd0e2f6b5..4ea5a1573 100644 --- a/libnd4j/include/ops/specials_sparse.h +++ b/libnd4j/include/ops/specials_sparse.h @@ -60,7 +60,7 @@ namespace sd { static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right, int rank); - static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank); + static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank); }; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 3053e55e2..ecae7e272 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1022,7 +1022,7 @@ public interface NativeOps { boolean descending); - void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer values, long length, int rank); + void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo); LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 916ba0a08..bca2a4dbf 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3135,9 +3135,21 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin // special sort impl for sorting out COO indices and values -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank); -public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong*") LongPointer indices, + Pointer x, + @Cast("Nd4jLong") long length, + @Cast("const Nd4jLong*") LongPointer xShapeInfo); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong*") LongBuffer indices, + Pointer x, + @Cast("Nd4jLong") long length, + @Cast("const Nd4jLong*") LongBuffer xShapeInfo); +public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, + @Cast("Nd4jLong*") long[] indices, + Pointer x, + @Cast("Nd4jLong") long length, + @Cast("const Nd4jLong*") long[] xShapeInfo); public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java index 39f5c10a9..c95359443 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/SortCooTests.java @@ -16,6 +16,8 @@ package org.nd4j.linalg.specials; +import com.google.common.primitives.Doubles; +import com.google.common.primitives.Floats; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.LongPointer; @@ -29,6 +31,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.rng.Random; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.nativeblas.NativeOpsHolder; @@ -46,20 +49,22 @@ import static org.junit.Assert.assertArrayEquals; public class SortCooTests extends BaseNd4jTest { DataType initialType; + DataType initialDefaultType; public SortCooTests(Nd4jBackend backend) { super(backend); this.initialType = Nd4j.dataType(); + this.initialDefaultType = Nd4j.defaultFloatingPointType(); } @Before public void setUp() { - Nd4j.setDataType(DataType.FLOAT); + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } @After public void setDown() { - Nd4j.setDataType(initialType); + Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType()); } @Test @@ -83,19 +88,17 @@ public class SortCooTests extends BaseNd4jTest { 1, 1, 1}; double expValues[] = new double[] {0, 1, 2, 3}; - DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); - DataBuffer val = Nd4j.createBuffer(values); + for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) { + DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); + DataBuffer val = Nd4j.createTypedBuffer(values, dataType); + DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst(); + NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), + val.addressPointer(), 4, (LongPointer) shapeInfo.addressPointer()); -// log.info("Old indices: {}", Arrays.toString(idx.asInt())); + assertArrayEquals(expIndices, idx.asLong()); + assertArrayEquals(expValues, val.asDouble(), 1e-5); - NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), - val.addressPointer(), 4, 3); - - -// log.info("New indices: {}", Arrays.toString(idx.asInt())); - - assertArrayEquals(expIndices, idx.asLong()); - assertArrayEquals(expValues, val.asDouble(), 1e-5); + } } @Test @@ -117,14 +120,20 @@ public class SortCooTests extends BaseNd4jTest { 2, 2, 2}; double expValues[] = new double[] {2, 3, 1}; - DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); - DataBuffer val = Nd4j.createBuffer(values); - NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), - val.addressPointer(), 3, 3); + for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) { + DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); + DataBuffer val = Nd4j.createTypedBuffer(values, dataType); + DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst(); + NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), + val.addressPointer(), 3, (LongPointer) shapeInfo.addressPointer()); + + assertArrayEquals(expIndices, idx.asLong()); + assertArrayEquals(expValues, val.asDouble(), 1e-5); + + } + - assertArrayEquals(expIndices, idx.asLong()); - assertArrayEquals(expValues, val.asDouble(), 1e-5); } /** @@ -155,10 +164,11 @@ public class SortCooTests extends BaseNd4jTest { DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices); DataBuffer valueBuffer = Nd4j.createBuffer(values); + DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, valueBuffer.dataType()).getFirst(); INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length}); NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(), - valueBuffer.addressPointer(), nnz, 3); + valueBuffer.addressPointer(), nnz, (LongPointer) shapeInfo.addressPointer()); for (long i = 1; i < nnz; ++i){ for(long j = 0; j < shape.length; ++j){ @@ -273,9 +283,10 @@ public class SortCooTests extends BaseNd4jTest { DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices); DataBuffer val = Nd4j.createBuffer(values); + DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, val.dataType()).getFirst(); NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(), - val.addressPointer(), 40, 3); + val.addressPointer(), 40, (LongPointer) shapeInfo.addressPointer()); // just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that // indices and values are both swapped. This test just makes sure index sort works for larger arrays.