update sortCooIndicesGeneric to take any data type (#9121)

Previously, this function only worked correctly for 64bit data types.

Signed-off-by: Péter Zarándy <pza@wehowsky.com>
master
pza94 2020-11-27 07:08:25 +01:00 committed by GitHub
parent 8e591bbf39
commit a1fcc5f19f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 66 additions and 34 deletions

View File

@ -28,6 +28,7 @@
#include <ops/specials_sparse.h>
#include <execution/LaunchContext.h>
#include <array/ArrayOptions.h>
#include <helpers/shape.h>
/**
* Native op executioner:
@ -652,8 +653,11 @@ static void execTransformBool(sd::LaunchContext *lc,
BUILD_SINGLE_SELECTOR(xType, sd::SpecialMethods, ::sortTadGeneric(x, xShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, descending), LIBND4J_TYPES);
}
inline static void execSortCooIndices(Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
sd::sparse::SparseUtils<Nd4jLong>::sortCooIndicesGeneric(indices, reinterpret_cast<Nd4jLong *>(values), length, rank);
inline static void execSortCooIndices(Nd4jLong *indices, void *x, Nd4jLong length, const Nd4jLong *xShapeInfo) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);
int rank = shape::rank(xShapeInfo);
BUILD_SINGLE_SELECTOR(xType, sd::sparse::SparseUtils, ::sortCooIndicesGeneric(indices, x, length, rank), LIBND4J_TYPES);
}

View File

@ -1476,7 +1476,11 @@ ND4J_EXPORT void sortTadByValue(Nd4jPointer *extraPointers,
// special sort impl for sorting out COO indices and values
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank);
ND4J_EXPORT void sortCooIndices(Nd4jPointer *extraPointers,
Nd4jLong *indices,
void *x,
Nd4jLong length,
const Nd4jLong *xShapeInfo);
ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, Nd4jLong length);

View File

@ -1832,11 +1832,11 @@ void sortTad(Nd4jPointer *extraPointers,
void sortCooIndices(Nd4jPointer *extraPointers,
Nd4jLong *indices,
void *values,
void *x,
Nd4jLong length,
int rank) {
const Nd4jLong *xShapeInfo) {
try {
NativeOpExecutioner::execSortCooIndices(indices, values, length, rank);
NativeOpExecutioner::execSortCooIndices(indices, x, length, xShapeInfo);
} catch (std::exception &e) {
sd::LaunchContext::defaultContext()->errorReference()->setErrorCode(1);
sd::LaunchContext::defaultContext()->errorReference()->setErrorMessage(e.what());

View File

@ -2560,7 +2560,7 @@ void sortTad(Nd4jPointer *extraPointers,
}
}
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, int rank) {
void sortCooIndices(Nd4jPointer *extraPointers, Nd4jLong *indices, void *values, Nd4jLong length, const Nd4jLong *xShapeInfo) {
throw std::runtime_error("sortCooIndices:: Not implemented yet");
}

View File

@ -209,7 +209,8 @@ PRAGMA_OMP_SINGLE_ARGS(nowait)
}
template <typename T>
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank) {
void SparseUtils<T>::sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank) {
auto values = reinterpret_cast<T *>(vx);
#ifdef _OPENMP
coo_quickSort_parallel(indices, values, length, omp_get_max_threads(), rank);
#else

View File

@ -60,7 +60,7 @@ namespace sd {
static Nd4jLong coo_quickSort_findPivot(Nd4jLong *indices, T *array, Nd4jLong left, Nd4jLong right,
int rank);
static void sortCooIndicesGeneric(Nd4jLong *indices, T *values, Nd4jLong length, int rank);
static void sortCooIndicesGeneric(Nd4jLong *indices, void *vx, Nd4jLong length, int rank);
};
}
}

View File

@ -1022,7 +1022,7 @@ public interface NativeOps {
boolean descending);
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer values, long length, int rank);
void sortCooIndices(PointerPointer extraPointers, @Cast("Nd4jLong *") LongPointer indices, Pointer x, long length, @Cast("Nd4jLong *") LongPointer shapeInfo);
LongPointer mmapFile(PointerPointer extraPointers, String fileName, long length);

View File

@ -3135,9 +3135,21 @@ public native void sortTadByValue(@Cast("Nd4jPointer*") PointerPointer extraPoin
// special sort impl for sorting out COO indices and values
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongPointer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] indices, Pointer values, @Cast("Nd4jLong") long length, int rank);
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
@Cast("Nd4jLong*") LongPointer indices,
Pointer x,
@Cast("Nd4jLong") long length,
@Cast("const Nd4jLong*") LongPointer xShapeInfo);
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
@Cast("Nd4jLong*") LongBuffer indices,
Pointer x,
@Cast("Nd4jLong") long length,
@Cast("const Nd4jLong*") LongBuffer xShapeInfo);
public native void sortCooIndices(@Cast("Nd4jPointer*") PointerPointer extraPointers,
@Cast("Nd4jLong*") long[] indices,
Pointer x,
@Cast("Nd4jLong") long length,
@Cast("const Nd4jLong*") long[] xShapeInfo);
public native @Cast("Nd4jLong*") LongPointer mmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String fileName, @Cast("Nd4jLong") long length);

View File

@ -16,6 +16,8 @@
package org.nd4j.linalg.specials;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Floats;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.LongPointer;
@ -29,6 +31,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.nativeblas.NativeOpsHolder;
@ -46,20 +49,22 @@ import static org.junit.Assert.assertArrayEquals;
public class SortCooTests extends BaseNd4jTest {
DataType initialType;
DataType initialDefaultType;
public SortCooTests(Nd4jBackend backend) {
super(backend);
this.initialType = Nd4j.dataType();
this.initialDefaultType = Nd4j.defaultFloatingPointType();
}
@Before
public void setUp() {
Nd4j.setDataType(DataType.FLOAT);
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
}
@After
public void setDown() {
Nd4j.setDataType(initialType);
Nd4j.setDefaultDataTypes(initialType, Nd4j.defaultFloatingPointType());
}
@Test
@ -83,19 +88,17 @@ public class SortCooTests extends BaseNd4jTest {
1, 1, 1};
double expValues[] = new double[] {0, 1, 2, 3};
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
DataBuffer val = Nd4j.createBuffer(values);
// log.info("Old indices: {}", Arrays.toString(idx.asInt()));
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
val.addressPointer(), 4, 3);
// log.info("New indices: {}", Arrays.toString(idx.asInt()));
val.addressPointer(), 4, (LongPointer) shapeInfo.addressPointer());
assertArrayEquals(expIndices, idx.asLong());
assertArrayEquals(expValues, val.asDouble(), 1e-5);
}
}
@Test
@ -117,14 +120,20 @@ public class SortCooTests extends BaseNd4jTest {
2, 2, 2};
double expValues[] = new double[] {2, 3, 1};
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
DataBuffer val = Nd4j.createBuffer(values);
for (DataType dataType : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.FLOAT16, DataType.INT64, DataType.INT32, DataType.INT16, DataType.INT8}) {
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
DataBuffer val = Nd4j.createTypedBuffer(values, dataType);
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{2, 2, 2}, val.dataType()).getFirst();
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
val.addressPointer(), 3, 3);
val.addressPointer(), 3, (LongPointer) shapeInfo.addressPointer());
assertArrayEquals(expIndices, idx.asLong());
assertArrayEquals(expValues, val.asDouble(), 1e-5);
}
}
/**
@ -155,10 +164,11 @@ public class SortCooTests extends BaseNd4jTest {
DataBuffer indiceBuffer = Nd4j.getDataBufferFactory().createLong(indices);
DataBuffer valueBuffer = Nd4j.createBuffer(values);
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, valueBuffer.dataType()).getFirst();
INDArray indMatrix = Nd4j.create(indiceBuffer).reshape(new long[]{nnz, shape.length});
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) indiceBuffer.addressPointer(),
valueBuffer.addressPointer(), nnz, 3);
valueBuffer.addressPointer(), nnz, (LongPointer) shapeInfo.addressPointer());
for (long i = 1; i < nnz; ++i){
for(long j = 0; j < shape.length; ++j){
@ -273,9 +283,10 @@ public class SortCooTests extends BaseNd4jTest {
DataBuffer idx = Nd4j.getDataBufferFactory().createLong(indices);
DataBuffer val = Nd4j.createBuffer(values);
DataBuffer shapeInfo = Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{3,3,3}, val.dataType()).getFirst();
NativeOpsHolder.getInstance().getDeviceNativeOps().sortCooIndices(null, (LongPointer) idx.addressPointer(),
val.addressPointer(), 40, 3);
val.addressPointer(), 40, (LongPointer) shapeInfo.addressPointer());
// just check the indices. sortSparseCooIndicesSort1 and sortSparseCooIndicesSort2 checks that
// indices and values are both swapped. This test just makes sure index sort works for larger arrays.