From 45ebd4899c009cf7776abe275c2ef5269819245a Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 2 Jun 2020 10:43:12 +0300 Subject: [PATCH] CUDA small sort tests (#482) * couple of C++ sort tests Signed-off-by: raver119@gmail.com * Java sort test Signed-off-by: raver119@gmail.com --- libnd4j/include/array/NDArray.h | 8 +-- .../layers_tests/LegacyOpsCudaTests.cu | 52 +++++++++++++++++++ .../java/org/nd4j/nativeblas/Nd4jCuda.java | 15 +++--- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 15 +++--- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 8 +++ 5 files changed, 82 insertions(+), 16 deletions(-) diff --git a/libnd4j/include/array/NDArray.h b/libnd4j/include/array/NDArray.h index 04500a987..c314d25b6 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -354,11 +354,11 @@ namespace sd { * @param writeList * @param readList */ - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerSpecialUse(const std::vector& writeList, const std::vector& readList = {}); + static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList = {}, bool synchronizeWritables = false); - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList = {}); + static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList = {}, bool synchronizeWritables = false); /** * This method returns buffer pointer offset by given number of elements, wrt own data type diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu index 53179cd68..622ce9fbb 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsCudaTests.cu @@ -58,3 +58,55 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) { ASSERT_EQ(e, x); } + +TEST_F(LegacyOpsCudaTests, test_sort_1) { + auto x = NDArrayFactory::create('c', {4}, {4.f, 2.f, 1.f, 3.f}); + auto e = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_2) { + auto x = NDArrayFactory::create('c', {4}, {4.f, 2.f, 1.f, 3.f}); + auto e = NDArrayFactory::create('c', {4}, {4.f, 3.f, 2.f, 1.f}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), true); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_3) { + auto x = NDArrayFactory::create('c', {4}, {0.5, 0.4, 0.1, 0.2}); + auto e = NDArrayFactory::create('c', {4}, {0.1, 0.2, 0.4, 0.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_4) { + auto x = NDArrayFactory::create('c', {4}, {7, 4, 9, 2}); + auto e = NDArrayFactory::create('c', {4}, {2, 4, 7, 9}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index ad9503849..cc6ffc19a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3849,13 +3849,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -5043,6 +5045,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #define LIBND4J_GRAPH_RNG_H // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 402b096c6..f17f11093 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3853,13 +3853,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * @param writeList * @param readList */ - public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList); - public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); - public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList); + public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); + public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList); /** * This method returns buffer pointer offset by given number of elements, wrt own data type @@ -5047,6 +5049,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #define LIBND4J_GRAPH_RNG_H // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index c9f5cef6f..e6c380b31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8484,6 +8484,14 @@ public class Nd4jTestsC extends BaseNd4jTest { } } + @Test + public void testSmallSort(){ + INDArray arr = Nd4j.createFromArray(0.5, 0.4, 0.1, 0.2); + INDArray expected = Nd4j.createFromArray(0.1, 0.2, 0.4, 0.5); + INDArray sorted = Nd4j.sort(arr, true); + assertEquals(expected, sorted); + } + @Override public char ordering() { return 'c';