update sortCooIndicesGeneric to take any data type (#9121)
Previously, this function only worked correctly for 64bit data types. Signed-off-by: Péter Zarándy <pza@wehowsky.com>
This commit is contained in:
		
							parent
							
								
									8e591bbf39
								
							
						
					
					
						commit
						a1fcc5f19f
					
				@ -28,6 +28,7 @@
 | 
			
		||||
#include <ops/specials_sparse.h>
 | 
			
		||||
#include <execution/LaunchContext.h>
 | 
			
		||||
#include <array/ArrayOptions.h>
 | 
			
		||||
#include <helpers/shape.h>
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Native op executioner:
 | 
			
		||||
@ -652,8 +653,11 @@ static void execTransformBool(sd::LaunchContext  *lc,
 | 
			
		||||
        BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
 | 
			
		||||
        sd::sparse::SparseUtils<Nd4jLong>::sortCooIndicesGeneric(indices, reinterpret_cast<Nd4jLong *>(values), length, rank);
 | 
			
		||||
    inline static void execSortCooIndices(Nd4jLong *indices, void *x, Nd4jLong length, const Nd4jLong *xShapeInfo) {
 | 
			
		||||
        auto xType = sd::ArrayOptions::dataType(xShapeInfo);
 | 
			
		||||
        int rank = shape::rank(xShapeInfo);
 | 
			
		||||
 | 
			
		||||
        BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1476,7 +1476,11 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// special sort impl for sorting out COO indices and values
 | 
			
		||||
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
 | 
			
		||||
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
 | 
			
		||||
                                Nd4jLong *indices,
 | 
			
		||||
                                void *x,
 | 
			
		||||
                                Nd4jLong length,
 | 
			
		||||
                                const Nd4jLong *xShapeInfo);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
 | 
			
		||||
 | 
			
		||||
@ -1832,11 +1832,11 @@ void sortTad(Nd4jPointer *extraPointers,
 | 
			
		||||
 | 
			
		||||
void sortCooIndices(Nd4jPointer *extraPointers,
 | 
			
		||||
        Nd4jLong *indices,
 | 
			
		||||
        void *values,
 | 
			
		||||
        void *x,
 | 
			
		||||
        Nd4jLong length,
 | 
			
		||||
        int rank) {
 | 
			
		||||
        const Nd4jLong *xShapeInfo) {
 | 
			
		||||
    try {
 | 
			
		||||
        NativeOpExecutioner::execSortCooIndices(indices, values, length, rank);
 | 
			
		||||
        NativeOpExecutioner::execSortCooIndices(indices, x, length, xShapeInfo);
 | 
			
		||||
    } catch (std::exception &e) {
 | 
			
		||||
        sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
 | 
			
		||||
        sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
 | 
			
		||||
 | 
			
		||||
@ -2560,7 +2560,7 @@ void sortTad(Nd4jPointer *extraPointers,
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
 | 
			
		||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, const Nd4jLong *xShapeInfo) {
 | 
			
		||||
	throw std::runtime_error("sortCooIndices:: Not implemented yet");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -209,7 +209,8 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        template <typename T>
 | 
			
		||||
        void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) {
 | 
			
		||||
        void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
 | 
			
		||||
        auto values = reinterpret_cast<T *>(vx);
 | 
			
		||||
#ifdef _OPENMP
 | 
			
		||||
            coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
 | 
			
		||||
#else
 | 
			
		||||
 | 
			
		||||
@ -60,7 +60,7 @@ namespace sd {
 | 
			
		||||
            static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
 | 
			
		||||
                                                    int rank);
 | 
			
		||||
 | 
			
		||||
            static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank);
 | 
			
		||||
            static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
 | 
			
		||||
        };
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1022,7 +1022,7 @@ public interface NativeOps {
 | 
			
		||||
                 boolean descending);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer values, long length, int rank);
 | 
			
		||||
    void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
 | 
			
		||||
 | 
			
		||||
@ -3135,9 +3135,21 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// special sort impl for sorting out COO indices and values
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
 | 
			
		||||
                                @Cast("Nd4jLong*") LongPointer indices,
 | 
			
		||||
                                Pointer x,
 | 
			
		||||
                                @Cast("Nd4jLong") long length,
 | 
			
		||||
                                @Cast("const Nd4jLong*") LongPointer xShapeInfo);
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
 | 
			
		||||
                                @Cast("Nd4jLong*") LongBuffer indices,
 | 
			
		||||
                                Pointer x,
 | 
			
		||||
                                @Cast("Nd4jLong") long length,
 | 
			
		||||
                                @Cast("const Nd4jLong*") LongBuffer xShapeInfo);
 | 
			
		||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
 | 
			
		||||
                                @Cast("Nd4jLong*") long[] indices,
 | 
			
		||||
                                Pointer x,
 | 
			
		||||
                                @Cast("Nd4jLong") long length,
 | 
			
		||||
                                @Cast("const Nd4jLong*") long[] xShapeInfo);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,8 @@
 | 
			
		||||
 | 
			
		||||
package org.nd4j.linalg.specials;
 | 
			
		||||
 | 
			
		||||
import com.google.common.primitives.Doubles;
 | 
			
		||||
import com.google.common.primitives.Floats;
 | 
			
		||||
import lombok.extern.slf4j.Slf4j;
 | 
			
		||||
import lombok.val;
 | 
			
		||||
import org.bytedeco.javacpp.LongPointer;
 | 
			
		||||
@ -29,6 +31,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
 | 
			
		||||
import org.nd4j.linalg.api.buffer.DataType;
 | 
			
		||||
import org.nd4j.linalg.api.ndarray.INDArray;
 | 
			
		||||
import org.nd4j.linalg.api.rng.Random;
 | 
			
		||||
import org.nd4j.linalg.api.shape.Shape;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4j;
 | 
			
		||||
import org.nd4j.linalg.factory.Nd4jBackend;
 | 
			
		||||
import org.nd4j.nativeblas.NativeOpsHolder;
 | 
			
		||||
@ -46,20 +49,22 @@ import static org.junit.Assert.assertArrayEquals;
 | 
			
		||||
public class SortCooTests extends BaseNd4jTest {
 | 
			
		||||
 | 
			
		||||
    DataType initialType;
 | 
			
		||||
    DataType initialDefaultType;
 | 
			
		||||
 | 
			
		||||
    public SortCooTests(Nd4jBackend backend) {
 | 
			
		||||
        super(backend);
 | 
			
		||||
        this.initialType = Nd4j.dataType();
 | 
			
		||||
        this.initialDefaultType = Nd4j.defaultFloatingPointType();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Before
 | 
			
		||||
    public void setUp() {
 | 
			
		||||
        Nd4j.setDataType(DataType.FLOAT);
 | 
			
		||||
        Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @After
 | 
			
		||||
    public void setDown() {
 | 
			
		||||
        Nd4j.setDataType(initialType);
 | 
			
		||||
        Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@ -83,19 +88,17 @@ public class SortCooTests extends BaseNd4jTest {
 | 
			
		||||
                1, 1, 1};
 | 
			
		||||
        double expValues[] = new double[] {0, 1, 2, 3};
 | 
			
		||||
 | 
			
		||||
        DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
        DataBuffer val = Nd4j.createBuffer(values);
 | 
			
		||||
        for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
 | 
			
		||||
            DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
            DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
 | 
			
		||||
            DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
 | 
			
		||||
            NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
 | 
			
		||||
                    val.addressPointer(), 4, (LongPointer) shapeInfo.addressPointer());
 | 
			
		||||
 | 
			
		||||
//        log.info("Old indices: {}", Arrays.toString(idx.asInt()));
 | 
			
		||||
            assertArrayEquals(expIndices, idx.asLong());
 | 
			
		||||
            assertArrayEquals(expValues, val.asDouble(), 1e-5);
 | 
			
		||||
 | 
			
		||||
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
 | 
			
		||||
                val.addressPointer(), 4, 3);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
//        log.info("New indices: {}", Arrays.toString(idx.asInt()));
 | 
			
		||||
 | 
			
		||||
        assertArrayEquals(expIndices, idx.asLong());
 | 
			
		||||
        assertArrayEquals(expValues, val.asDouble(), 1e-5);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
@ -117,14 +120,20 @@ public class SortCooTests extends BaseNd4jTest {
 | 
			
		||||
                2, 2, 2};
 | 
			
		||||
        double expValues[] = new double[] {2, 3, 1};
 | 
			
		||||
 | 
			
		||||
        DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
        DataBuffer val = Nd4j.createBuffer(values);
 | 
			
		||||
 | 
			
		||||
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
 | 
			
		||||
                val.addressPointer(), 3, 3);
 | 
			
		||||
        for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
 | 
			
		||||
            DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
            DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
 | 
			
		||||
            DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
 | 
			
		||||
            NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
 | 
			
		||||
                    val.addressPointer(), 3, (LongPointer) shapeInfo.addressPointer());
 | 
			
		||||
 | 
			
		||||
            assertArrayEquals(expIndices, idx.asLong());
 | 
			
		||||
            assertArrayEquals(expValues, val.asDouble(), 1e-5);
 | 
			
		||||
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        assertArrayEquals(expIndices, idx.asLong());
 | 
			
		||||
        assertArrayEquals(expValues, val.asDouble(), 1e-5);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@ -155,10 +164,11 @@ public class SortCooTests extends BaseNd4jTest {
 | 
			
		||||
 | 
			
		||||
        DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
        DataBuffer valueBuffer = Nd4j.createBuffer(values);
 | 
			
		||||
        DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, valueBuffer.dataType()).getFirst();
 | 
			
		||||
        INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length});
 | 
			
		||||
 | 
			
		||||
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
 | 
			
		||||
                valueBuffer.addressPointer(), nnz, 3);
 | 
			
		||||
                valueBuffer.addressPointer(), nnz, (LongPointer) shapeInfo.addressPointer());
 | 
			
		||||
 | 
			
		||||
        for (long i = 1; i < nnz; ++i){
 | 
			
		||||
            for(long j = 0; j < shape.length; ++j){
 | 
			
		||||
@ -273,9 +283,10 @@ public class SortCooTests extends BaseNd4jTest {
 | 
			
		||||
 | 
			
		||||
        DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
 | 
			
		||||
        DataBuffer val = Nd4j.createBuffer(values);
 | 
			
		||||
        DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, val.dataType()).getFirst();
 | 
			
		||||
 | 
			
		||||
        NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
 | 
			
		||||
                val.addressPointer(), 40, 3);
 | 
			
		||||
                val.addressPointer(), 40, (LongPointer) shapeInfo.addressPointer());
 | 
			
		||||
 | 
			
		||||
        // just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that
 | 
			
		||||
        // indices and values are both swapped. This test just makes sure index sort works for larger arrays.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user