update sortCooIndicesGeneric to take any data type (#9121)
Previously, this function only worked correctly for 64bit data types. Signed-off-by: Péter Zarándy <pza@wehowsky.com>master
parent
8e591bbf39
commit
a1fcc5f19f
|
@ -28,6 +28,7 @@
|
|||
#include <ops/specials_sparse.h>
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <array/ArrayOptions.h>
|
||||
#include <helpers/shape.h>
|
||||
|
||||
/**
|
||||
* Native op executioner:
|
||||
|
@ -652,8 +653,11 @@ static void execTransformBool(sd::LaunchContext *lc,
|
|||
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
||||
sd::sparse::SparseUtils<Nd4jLong>::sortCooIndicesGeneric(indices, reinterpret_cast<Nd4jLong *>(values), length, rank);
|
||||
inline static void execSortCooIndices(Nd4jLong *indices, void *x, Nd4jLong length, const Nd4jLong *xShapeInfo) {
|
||||
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
|
||||
int rank = shape::rank(xShapeInfo);
|
||||
|
||||
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1476,7 +1476,11 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
|
|||
|
||||
|
||||
// special sort impl for sorting out COO indices and values
|
||||
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
|
||||
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
|
||||
Nd4jLong *indices,
|
||||
void *x,
|
||||
Nd4jLong length,
|
||||
const Nd4jLong *xShapeInfo);
|
||||
|
||||
|
||||
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);
|
||||
|
|
|
@ -1832,11 +1832,11 @@ void sortTad(Nd4jPointer *extraPointers,
|
|||
|
||||
void sortCooIndices(Nd4jPointer *extraPointers,
|
||||
Nd4jLong *indices,
|
||||
void *values,
|
||||
void *x,
|
||||
Nd4jLong length,
|
||||
int rank) {
|
||||
const Nd4jLong *xShapeInfo) {
|
||||
try {
|
||||
NativeOpExecutioner::execSortCooIndices(indices, values, length, rank);
|
||||
NativeOpExecutioner::execSortCooIndices(indices, x, length, xShapeInfo);
|
||||
} catch (std::exception &e) {
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
|
||||
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());
|
||||
|
|
|
@ -2560,7 +2560,7 @@ void sortTad(Nd4jPointer *extraPointers,
|
|||
}
|
||||
}
|
||||
|
||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
|
||||
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, const Nd4jLong *xShapeInfo) {
|
||||
throw std::runtime_error("sortCooIndices:: Not implemented yet");
|
||||
}
|
||||
|
||||
|
|
|
@ -209,7 +209,8 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) {
|
||||
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
|
||||
auto values = reinterpret_cast<T *>(vx);
|
||||
#ifdef _OPENMP
|
||||
coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
|
||||
#else
|
||||
|
|
|
@ -60,7 +60,7 @@ namespace sd {
|
|||
static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
|
||||
int rank);
|
||||
|
||||
static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank);
|
||||
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1022,7 +1022,7 @@ public interface NativeOps {
|
|||
boolean descending);
|
||||
|
||||
|
||||
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer values, long length, int rank);
|
||||
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
|
||||
|
||||
|
||||
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);
|
||||
|
|
|
@ -3135,9 +3135,21 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin
|
|||
|
||||
|
||||
// special sort impl for sorting out COO indices and values
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
@Cast("Nd4jLong*") LongPointer indices,
|
||||
Pointer x,
|
||||
@Cast("Nd4jLong") long length,
|
||||
@Cast("const Nd4jLong*") LongPointer xShapeInfo);
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
@Cast("Nd4jLong*") LongBuffer indices,
|
||||
Pointer x,
|
||||
@Cast("Nd4jLong") long length,
|
||||
@Cast("const Nd4jLong*") LongBuffer xShapeInfo);
|
||||
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
|
||||
@Cast("Nd4jLong*") long[] indices,
|
||||
Pointer x,
|
||||
@Cast("Nd4jLong") long length,
|
||||
@Cast("const Nd4jLong*") long[] xShapeInfo);
|
||||
|
||||
|
||||
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.nd4j.linalg.specials;
|
||||
|
||||
import com.google.common.primitives.Doubles;
|
||||
import com.google.common.primitives.Floats;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
|
@ -29,6 +31,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
@ -46,20 +49,22 @@ import static org.junit.Assert.assertArrayEquals;
|
|||
public class SortCooTests extends BaseNd4jTest {
|
||||
|
||||
DataType initialType;
|
||||
DataType initialDefaultType;
|
||||
|
||||
public SortCooTests(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
this.initialType = Nd4j.dataType();
|
||||
this.initialDefaultType = Nd4j.defaultFloatingPointType();
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
Nd4j.setDataType(DataType.FLOAT);
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
}
|
||||
|
||||
@After
|
||||
public void setDown() {
|
||||
Nd4j.setDataType(initialType);
|
||||
Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType());
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -83,19 +88,17 @@ public class SortCooTests extends BaseNd4jTest {
|
|||
1, 1, 1};
|
||||
double expValues[] = new double[] {0, 1, 2, 3};
|
||||
|
||||
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
|
||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||
DataBuffer val = Nd4j.createBuffer(values);
|
||||
|
||||
// log.info("Old indices: {}", Arrays.toString(idx.asInt()));
|
||||
|
||||
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
|
||||
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||
val.addressPointer(), 4, 3);
|
||||
|
||||
|
||||
// log.info("New indices: {}", Arrays.toString(idx.asInt()));
|
||||
val.addressPointer(), 4, (LongPointer) shapeInfo.addressPointer());
|
||||
|
||||
assertArrayEquals(expIndices, idx.asLong());
|
||||
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -117,14 +120,20 @@ public class SortCooTests extends BaseNd4jTest {
|
|||
2, 2, 2};
|
||||
double expValues[] = new double[] {2, 3, 1};
|
||||
|
||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||
DataBuffer val = Nd4j.createBuffer(values);
|
||||
|
||||
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
|
||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
|
||||
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||
val.addressPointer(), 3, 3);
|
||||
val.addressPointer(), 3, (LongPointer) shapeInfo.addressPointer());
|
||||
|
||||
assertArrayEquals(expIndices, idx.asLong());
|
||||
assertArrayEquals(expValues, val.asDouble(), 1e-5);
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -155,10 +164,11 @@ public class SortCooTests extends BaseNd4jTest {
|
|||
|
||||
DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
|
||||
DataBuffer valueBuffer = Nd4j.createBuffer(values);
|
||||
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, valueBuffer.dataType()).getFirst();
|
||||
INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length});
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
|
||||
valueBuffer.addressPointer(), nnz, 3);
|
||||
valueBuffer.addressPointer(), nnz, (LongPointer) shapeInfo.addressPointer());
|
||||
|
||||
for (long i = 1; i < nnz; ++i){
|
||||
for(long j = 0; j < shape.length; ++j){
|
||||
|
@ -273,9 +283,10 @@ public class SortCooTests extends BaseNd4jTest {
|
|||
|
||||
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
|
||||
DataBuffer val = Nd4j.createBuffer(values);
|
||||
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, val.dataType()).getFirst();
|
||||
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
|
||||
val.addressPointer(), 40, 3);
|
||||
val.addressPointer(), 40, (LongPointer) shapeInfo.addressPointer());
|
||||
|
||||
// just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that
|
||||
// indices and values are both swapped. This test just makes sure index sort works for larger arrays.
|
||||
|
|
Loading…
Reference in New Issue