CUDA small sort tests (#482)
* couple of C++ sort tests Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Java sort test Signed-off-by: raver119@gmail.com <raver119@gmail.com>master
parent
c783a5938a
commit
45ebd4899c
|
@ -354,11 +354,11 @@ namespace sd {
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {});
|
||||||
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {}, bool synchronizeWritables = false);
|
||||||
|
|
||||||
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {});
|
||||||
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList = {}, bool synchronizeWritables = false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||||
|
|
|
@ -58,3 +58,55 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
|
||||||
|
|
||||||
ASSERT_EQ(e, x);
|
ASSERT_EQ(e, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4}, {4.f, 2.f, 1.f, 3.f});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4}, {4.f, 3.f, 2.f, 1.f});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), true);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_3) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4}, {0.5, 0.4, 0.1, 0.2});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {4}, {0.1, 0.2, 0.4, 0.5});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsCudaTests, test_sort_4) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {4}, {7, 4, 9, 2});
|
||||||
|
auto e = NDArrayFactory::create<double>('c', {4}, {2, 4, 7, 9});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&x}, {&x});
|
||||||
|
::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false);
|
||||||
|
NDArray::registerSpecialUse({&x});
|
||||||
|
|
||||||
|
ASSERT_EQ(e, x);
|
||||||
|
}
|
|
@ -3849,13 +3849,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
|
||||||
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
||||||
|
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
|
|
||||||
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
|
||||||
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
||||||
|
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||||
|
@ -5043,6 +5045,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
// #define LIBND4J_GRAPH_RNG_H
|
// #define LIBND4J_GRAPH_RNG_H
|
||||||
|
|
||||||
// #include <types/u64.h>
|
// #include <types/u64.h>
|
||||||
|
// #include <types/u32.h>
|
||||||
// #include <system/pointercast.h>
|
// #include <system/pointercast.h>
|
||||||
// #include <system/op_boilerplate.h>
|
// #include <system/op_boilerplate.h>
|
||||||
// #include <system/dll.h>
|
// #include <system/dll.h>
|
||||||
|
|
|
@ -3853,13 +3853,15 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
|
||||||
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
public static native void registerSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
||||||
|
public static native void prepareSpecialUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
|
|
||||||
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList);
|
||||||
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
public static native void registerPrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef ConstNDArrayVector readList);
|
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList, @Const @ByRef(nullValue = "std::vector<const sd::NDArray*>({})") ConstNDArrayVector readList, @Cast("bool") boolean synchronizeWritables/*=false*/);
|
||||||
|
public static native void preparePrimaryUse(@Const @ByRef ConstNDArrayVector writeList);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||||
|
@ -5047,6 +5049,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
// #define LIBND4J_GRAPH_RNG_H
|
// #define LIBND4J_GRAPH_RNG_H
|
||||||
|
|
||||||
// #include <types/u64.h>
|
// #include <types/u64.h>
|
||||||
|
// #include <types/u32.h>
|
||||||
// #include <system/pointercast.h>
|
// #include <system/pointercast.h>
|
||||||
// #include <system/op_boilerplate.h>
|
// #include <system/op_boilerplate.h>
|
||||||
// #include <system/dll.h>
|
// #include <system/dll.h>
|
||||||
|
|
|
@ -8484,6 +8484,14 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testSmallSort(){
|
||||||
|
INDArray arr = Nd4j.createFromArray(0.5, 0.4, 0.1, 0.2);
|
||||||
|
INDArray expected = Nd4j.createFromArray(0.1, 0.2, 0.4, 0.5);
|
||||||
|
INDArray sorted = Nd4j.sort(arr, true);
|
||||||
|
assertEquals(expected, sorted);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
Loading…
Reference in New Issue