update sortCooIndicesGeneric to take any data type (#9121)
Previously, this function only worked correctly for 64bit data types. Signed-off-by: Péter Zarándy <pza@wehowsky.com>master
parent
8e591bbf39
commit
a1fcc5f19f
|
@ -28,6 +28,7 @@
|
||||||
#include <ops/specials_sparse.h>
|
#include <ops/specials_sparse.h>
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
#include <array/ArrayOptions.h>
|
#include <array/ArrayOptions.h>
|
||||||
|
#include <helpers/shape.h>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Native op executioner:
|
* Native op executioner:
|
||||||
|
@ -652,8 +653,11 @@ static void execTransformBool(sd::LaunchContext *lc,
|
||||||
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
inline static void execSortCooIndices(Nd4jLong *indices, void *x, Nd4jLong length, const Nd4jLong *xShapeInfo) {
|
||||||
sd::sparse::SparseUtils<Nd4jLong>::sortCooIndicesGeneric(indices, reinterpret_cast<Nd4jLong *>(values), length, rank);
|
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||||
|
int rank = shape::rank(xShapeInfo);
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1476,7 +1476,11 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
|
||||||
|
|
||||||
|
|
||||||
// special sort impl for sorting out COO indices and values
|
// special sort impl for sorting out COO indices and values
|
||||||
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
|
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
|
||||||
|
Nd4jLong *indices,
|
||||||
|
void *x,
|
||||||
|
Nd4jLong length,
|
||||||
|
const Nd4jLong *xShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
||||||
|
|
|
@ -1832,11 +1832,11 @@ void sortTad(Nd4jPointer *extraPointers,
|
||||||
|
|
||||||
void sortCooIndices(Nd4jPointer *extraPointers,
|
void sortCooIndices(Nd4jPointer *extraPointers,
|
||||||
Nd4jLong *indices,
|
Nd4jLong *indices,
|
||||||
void *values,
|
void *x,
|
||||||
Nd4jLong length,
|
Nd4jLong length,
|
||||||
int rank) {
|
const Nd4jLong *xShapeInfo) {
|
||||||
try {
|
try {
|
||||||
NativeOpExecutioner::execSortCooIndices(indices, values, length, rank);
|
NativeOpExecutioner::execSortCooIndices(indices, x, length, xShapeInfo);
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||||
|
|
|
@ -2560,7 +2560,7 @@ void sortTad(Nd4jPointer *extraPointers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, const Nd4jLong *xShapeInfo) {
|
||||||
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -209,7 +209,8 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) {
|
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
|
||||||
|
auto values = reinterpret_cast<T *>(vx);
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
|
coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
||||||
static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
|
static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
|
||||||
int rank);
|
int rank);
|
||||||
|
|
||||||
static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank);
|
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1022,7 +1022,7 @@ public interface NativeOps {
|
||||||
boolean descending);
|
boolean descending);
|
||||||
|
|
||||||
|
|
||||||
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer values, long length, int rank);
|
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||||
|
|
||||||
|
|
||||||
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
||||||
|
|
|
@ -3135,9 +3135,21 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin
|
||||||
|
|
||||||
|
|
||||||
// special sort impl for sorting out COO indices and values
|
// special sort impl for sorting out COO indices and values
|
||||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
@Cast("Nd4jLong*") LongPointer indices,
|
||||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
Pointer x,
|
||||||
|
@Cast("Nd4jLong") long length,
|
||||||
|
@Cast("const Nd4jLong*") LongPointer xShapeInfo);
|
||||||
|
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
@Cast("Nd4jLong*") LongBuffer indices,
|
||||||
|
Pointer x,
|
||||||
|
@Cast("Nd4jLong") long length,
|
||||||
|
@Cast("const Nd4jLong*") LongBuffer xShapeInfo);
|
||||||
|
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||||
|
@Cast("Nd4jLong*") long[] indices,
|
||||||
|
Pointer x,
|
||||||
|
@Cast("Nd4jLong") long length,
|
||||||
|
@Cast("const Nd4jLong*") long[] xShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.specials;
|
package org.nd4j.linalg.specials;
|
||||||
|
|
||||||
|
import com.google.common.primitives.Doubles;
|
||||||
|
import com.google.common.primitives.Floats;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.LongPointer;
|
import org.bytedeco.javacpp.LongPointer;
|
||||||
|
@ -29,6 +31,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
@ -46,20 +49,22 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
public class SortCooTests extends BaseNd4jTest {
|
public class SortCooTests extends BaseNd4jTest {
|
||||||
|
|
||||||
DataType initialType;
|
DataType initialType;
|
||||||
|
DataType initialDefaultType;
|
||||||
|
|
||||||
public SortCooTests(Nd4jBackend backend) {
|
public SortCooTests(Nd4jBackend backend) {
|
||||||
super(backend);
|
super(backend);
|
||||||
this.initialType = Nd4j.dataType();
|
this.initialType = Nd4j.dataType();
|
||||||
|
this.initialDefaultType = Nd4j.defaultFloatingPointType();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
Nd4j.setDataType(DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
public void setDown() {
|
public void setDown() {
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -83,19 +88,17 @@ public class SortCooTests extends BaseNd4jTest {
|
||||||
1, 1, 1};
|
1, 1, 1};
|
||||||
double expValues[] = new double[] {0, 1, 2, 3};
|
double expValues[] = new double[] {0, 1, 2, 3};
|
||||||
|
|
||||||
|
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
|
||||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||||
DataBuffer val = Nd4j.createBuffer(values);
|
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
|
||||||
|
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
|
||||||
// log.info("Old indices: {}", Arrays.toString(idx.asInt()));
|
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||||
val.addressPointer(), 4, 3);
|
val.addressPointer(), 4, (LongPointer) shapeInfo.addressPointer());
|
||||||
|
|
||||||
|
|
||||||
// log.info("New indices: {}", Arrays.toString(idx.asInt()));
|
|
||||||
|
|
||||||
assertArrayEquals(expIndices, idx.asLong());
|
assertArrayEquals(expIndices, idx.asLong());
|
||||||
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -117,14 +120,20 @@ public class SortCooTests extends BaseNd4jTest {
|
||||||
2, 2, 2};
|
2, 2, 2};
|
||||||
double expValues[] = new double[] {2, 3, 1};
|
double expValues[] = new double[] {2, 3, 1};
|
||||||
|
|
||||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
|
||||||
DataBuffer val = Nd4j.createBuffer(values);
|
|
||||||
|
|
||||||
|
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
|
||||||
|
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||||
|
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
|
||||||
|
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||||
val.addressPointer(), 3, 3);
|
val.addressPointer(), 3, (LongPointer) shapeInfo.addressPointer());
|
||||||
|
|
||||||
assertArrayEquals(expIndices, idx.asLong());
|
assertArrayEquals(expIndices, idx.asLong());
|
||||||
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -155,10 +164,11 @@ public class SortCooTests extends BaseNd4jTest {
|
||||||
|
|
||||||
DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
|
DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
|
||||||
DataBuffer valueBuffer = Nd4j.createBuffer(values);
|
DataBuffer valueBuffer = Nd4j.createBuffer(values);
|
||||||
|
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, valueBuffer.dataType()).getFirst();
|
||||||
INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length});
|
INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length});
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
|
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
|
||||||
valueBuffer.addressPointer(), nnz, 3);
|
valueBuffer.addressPointer(), nnz, (LongPointer) shapeInfo.addressPointer());
|
||||||
|
|
||||||
for (long i = 1; i < nnz; ++i){
|
for (long i = 1; i < nnz; ++i){
|
||||||
for(long j = 0; j < shape.length; ++j){
|
for(long j = 0; j < shape.length; ++j){
|
||||||
|
@ -273,9 +283,10 @@ public class SortCooTests extends BaseNd4jTest {
|
||||||
|
|
||||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||||
DataBuffer val = Nd4j.createBuffer(values);
|
DataBuffer val = Nd4j.createBuffer(values);
|
||||||
|
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, val.dataType()).getFirst();
|
||||||
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||||
val.addressPointer(), 40, 3);
|
val.addressPointer(), 40, (LongPointer) shapeInfo.addressPointer());
|
||||||
|
|
||||||
// just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that
|
// just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that
|
||||||
// indices and values are both swapped. This test just makes sure index sort works for larger arrays.
|
// indices and values are both swapped. This test just makes sure index sort works for larger arrays.
|
||||||
|
|
Loading…
Reference in New Issue