CUDA small sort tests (#482)

* couple of C++ sort tests

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Java sort test

Signed-off-by: raver119@gmail.com <raver119@gmail.com>
master
raver119 2020-06-02 10:43:12 +03:00 committed by GitHub
parent c783a5938a
commit 45ebd4899c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 82 additions and 16 deletions

View File

@ -354,11 +354,11 @@ namespace sd {
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList); static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {});
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false); static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {}, bool synchronizeWritables = false);
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList); static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {});
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false); static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {}, bool synchronizeWritables = false);
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type

View File

@ -58,3 +58,55 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
ASSERT_EQ(e, x); ASSERT_EQ(e, x);
} }
TEST_F(LegacyOpsCudaTests, test_sort_1) {
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
auto e = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NDArray::prepareSpecialUse({&x}, {&x});
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
NDArray::registerSpecialUse({&x});
ASSERT_EQ(e, x);
}
TEST_F(LegacyOpsCudaTests, test_sort_2) {
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
auto e = NDArrayFactory::create<float>('c', {4}, {4.f, 3.f, 2.f, 1.f});
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NDArray::prepareSpecialUse({&x}, {&x});
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), true);
NDArray::registerSpecialUse({&x});
ASSERT_EQ(e, x);
}
TEST_F(LegacyOpsCudaTests, test_sort_3) {
auto x = NDArrayFactory::create<double>('c', {4}, {0.5, 0.4, 0.1, 0.2});
auto e = NDArrayFactory::create<double>('c', {4}, {0.1, 0.2, 0.4, 0.5});
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NDArray::prepareSpecialUse({&x}, {&x});
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
NDArray::registerSpecialUse({&x});
ASSERT_EQ(e, x);
}
TEST_F(LegacyOpsCudaTests, test_sort_4) {
auto x = NDArrayFactory::create<double>('c', {4}, {7, 4, 9, 2});
auto e = NDArrayFactory::create<double>('c', {4}, {2, 4, 7, 9});
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NDArray::prepareSpecialUse({&x}, {&x});
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
NDArray::registerSpecialUse({&x});
ASSERT_EQ(e, x);
}

View File

@ -3849,13 +3849,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type
@ -5043,6 +5045,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #define LIBND4J_GRAPH_RNG_H // #define LIBND4J_GRAPH_RNG_H
// #include <types/u64.h> // #include <types/u64.h>
// #include <types/u32.h>
// #include <system/pointercast.h> // #include <system/pointercast.h>
// #include <system/op_boilerplate.h> // #include <system/op_boilerplate.h>
// #include <system/dll.h> // #include <system/dll.h>

View File

@ -3853,13 +3853,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* @param writeList * @param writeList
* @param readList * @param readList
*/ */
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/); public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList); public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type
@ -5047,6 +5049,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #define LIBND4J_GRAPH_RNG_H // #define LIBND4J_GRAPH_RNG_H
// #include <types/u64.h> // #include <types/u64.h>
// #include <types/u32.h>
// #include <system/pointercast.h> // #include <system/pointercast.h>
// #include <system/op_boilerplate.h> // #include <system/op_boilerplate.h>
// #include <system/dll.h> // #include <system/dll.h>

View File

@ -8484,6 +8484,14 @@ public class Nd4jTestsC extends BaseNd4jTest {
} }
} }
@Test
public void testSmallSort(){
INDArray arr = Nd4j.createFromArray(0.5, 0.4, 0.1, 0.2);
INDArray expected = Nd4j.createFromArray(0.1, 0.2, 0.4, 0.5);
INDArray sorted = Nd4j.sort(arr, true);
assertEquals(expected, sorted);
}
@Override @Override
public char ordering() { public char ordering() {
return 'c'; return 'c';