[WIP] reverse improvements (#115)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* reverse draft

Signed-off-by: raver119 <raver119@gmail.com>

* reverse kernel

Signed-off-by: raver119 <raver119@gmail.com>

* reverse kernel

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-05 20:03:10 +03:00 committed by GitHub
parent 0e8a4f77bc
commit 355c6b6096
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 139 additions and 18 deletions

View File

@ -30,6 +30,67 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) {
auto input = reinterpret_cast<T*>(vinput);
auto output = reinterpret_cast<T*>(voutput);
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
// this means that we'll have additional cycle, to move middle element
auto div = numOfElemsToReverse / 2;
auto odd = numOfElemsToReverse % 2 != 0;
auto rlimit = odd ? limit / 2 + 1 : limit / 2;
// all threads operate in the same input/output space
for (uint64_t e = tid; e < rlimit; e += step) {
// finding out the TAD we're going to process
auto tadId = e / div;
if (tadId >= numTads)
continue;
// now finding out element within tad
auto idx = e % div;
//printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx);
auto tadInput = input + inputTadOffsets[tadId];
auto tadOutput = output + outputTadOffsets[tadId];
// we're calculating offsets within input TAD
auto fOffset = shape::getIndexOffset(idx, inputTadShape);
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape);
// now we're storing input values
auto v1 = tadInput[fOffset];
auto v2 = tadInput[lOffset];
// now we're calculating offsets within output TAD
auto zfOffset = shape::getIndexOffset(idx, outputTadShape);
auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape);
// and saving values to output arrays
tadOutput[zfOffset] = v2;
tadOutput[zlOffset] = v1;
}
// moving odd element in blocks
if (odd && threadIdx.x == 0) {
for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) {
auto tadInput = input + inputTadOffsets[e];
auto tadOutput = output + outputTadOffsets[e];
auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape);
auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape);
tadOutput[zOffset] = tadInput[xOffset];
}
}
}
template <typename T> template <typename T>
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -52,7 +113,7 @@ namespace helpers {
auto odd = numOfElemsToReverse % 2 != 0; auto odd = numOfElemsToReverse % 2 != 0;
auto limit = numOfElemsToReverse / 2; auto limit = numOfElemsToReverse / 2;
for (Nd4jLong e = tid; e < limit; e += step) { for (uint64_t e = tid; e < limit; e += step) {
// we're calculating offsets within input array // we're calculating offsets within input array
auto fOffset = shape::getIndexOffset(e, inputShape); auto fOffset = shape::getIndexOffset(e, inputShape);
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
@ -80,13 +141,19 @@ namespace helpers {
} }
template<typename T> template<typename T>
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) {
auto stream = context->getCudaStream();
reverseTadKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength);
}
template<typename T>
static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numOfReverse = numOfElemsToReverse; Nd4jLong numOfReverse = numOfElemsToReverse;
if (numOfElemsToReverse == 0) if (numOfElemsToReverse == 0)
numOfReverse = input->lengthOf(); numOfReverse = input->lengthOf();
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
} }
@ -153,27 +220,23 @@ namespace helpers {
// we need to reverse axis only if that's new op // we need to reverse axis only if that's new op
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto listOut = output->allTensorsAlongDimension(dimensions);
auto listIn = input->allTensorsAlongDimension(dimensions);
NDArray *subArrIn, *subArrOut;
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
subArrIn = listIn->at(i); if (packX.numberOfTads() == 1) {
subArrOut = listOut->at(i); BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES); } else {
BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES);
} }
//BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast<NDArray*>(input), output, (int)0), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
delete listOut;
delete listIn;
} }
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
} }
} }

View File

@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0); auto result = results->at(0);
// result->printBuffer(); //expected.printIndexedBuffer("E");
//result->printIndexedBuffer("R");
ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));

View File

@ -197,3 +197,44 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
delete shapes; delete shapes;
} }
TEST_F(DeclarableOpsTests16, test_reverse_1) {
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211};
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635};
for (auto r : rows) {
for (auto c : columns) {
//nd4j_printf("Trying [%i, %i]\n", r, c);
auto array = NDArrayFactory::create<float>('c', {r, c});
auto exp = NDArrayFactory::create<float>('c', {r, c});
auto reversed = NDArrayFactory::create<float>('c', {r, c});
auto rowOriginal = NDArrayFactory::create<float>('c', {c});
auto rowReversed = NDArrayFactory::create<float>('c', {c});
for (int e = 0; e < c; e++) {
rowOriginal.p(e, (float) e);
rowReversed.p(c - e - 1, (float) e);
}
auto listI = array.allTensorsAlongDimension({1});
auto listE = exp.allTensorsAlongDimension({1});
for (int e = 0; e < r; e++) {
listI->at(e)->assign(rowOriginal);
listE->at(e)->assign(rowReversed);
}
delete listI;
delete listE;
nd4j::ops::reverse op;
Nd4jLong axis = 1;
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp, reversed);
}
}
}

View File

@ -24,6 +24,7 @@
#include <NDArray.h> #include <NDArray.h>
#include <ops/ops.h> #include <ops/ops.h>
#include <GradCheck.h> #include <GradCheck.h>
#include <chrono>
using namespace nd4j; using namespace nd4j;
@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
//ASSERT_TRUE(exp.isSameShape(z)); //ASSERT_TRUE(exp.isSameShape(z));
delete result; delete result;
} }
/*
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
auto z = x.like();
x.linspace(1.0f);
nd4j::ops::reverse op;
auto timeStart = std::chrono::system_clock::now();
auto status = op.execute({&x}, {&z}, {}, {1}, {});
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
nd4j_printf("exec time: %lld us\n", outerTime);
ASSERT_EQ(Status::OK(), status);
}
*/