[WIP] reverse improvements (#115)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * reverse draft Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com>master
parent
0e8a4f77bc
commit
355c6b6096
|
@ -30,6 +30,67 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) {
|
||||
auto input = reinterpret_cast<T*>(vinput);
|
||||
auto output = reinterpret_cast<T*>(voutput);
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto step = gridDim.x * blockDim.x;
|
||||
|
||||
// this means that we'll have additional cycle, to move middle element
|
||||
auto div = numOfElemsToReverse / 2;
|
||||
auto odd = numOfElemsToReverse % 2 != 0;
|
||||
auto rlimit = odd ? limit / 2 + 1 : limit / 2;
|
||||
|
||||
// all threads operate in the same input/output space
|
||||
for (uint64_t e = tid; e < rlimit; e += step) {
|
||||
// finding out the TAD we're going to process
|
||||
auto tadId = e / div;
|
||||
|
||||
if (tadId >= numTads)
|
||||
continue;
|
||||
|
||||
// now finding out element within tad
|
||||
auto idx = e % div;
|
||||
|
||||
//printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx);
|
||||
|
||||
auto tadInput = input + inputTadOffsets[tadId];
|
||||
auto tadOutput = output + outputTadOffsets[tadId];
|
||||
|
||||
// we're calculating offsets within input TAD
|
||||
auto fOffset = shape::getIndexOffset(idx, inputTadShape);
|
||||
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape);
|
||||
|
||||
// now we're storing input values
|
||||
auto v1 = tadInput[fOffset];
|
||||
auto v2 = tadInput[lOffset];
|
||||
|
||||
// now we're calculating offsets within output TAD
|
||||
auto zfOffset = shape::getIndexOffset(idx, outputTadShape);
|
||||
auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape);
|
||||
|
||||
// and saving values to output arrays
|
||||
tadOutput[zfOffset] = v2;
|
||||
tadOutput[zlOffset] = v1;
|
||||
}
|
||||
|
||||
// moving odd element in blocks
|
||||
if (odd && threadIdx.x == 0) {
|
||||
for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) {
|
||||
auto tadInput = input + inputTadOffsets[e];
|
||||
auto tadOutput = output + outputTadOffsets[e];
|
||||
|
||||
auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape);
|
||||
auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape);
|
||||
|
||||
tadOutput[zOffset] = tadInput[xOffset];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
@ -52,7 +113,7 @@ namespace helpers {
|
|||
auto odd = numOfElemsToReverse % 2 != 0;
|
||||
auto limit = numOfElemsToReverse / 2;
|
||||
|
||||
for (Nd4jLong e = tid; e < limit; e += step) {
|
||||
for (uint64_t e = tid; e < limit; e += step) {
|
||||
// we're calculating offsets within input array
|
||||
auto fOffset = shape::getIndexOffset(e, inputShape);
|
||||
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
|
||||
|
@ -80,13 +141,19 @@ namespace helpers {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
||||
static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) {
|
||||
auto stream = context->getCudaStream();
|
||||
reverseTadKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
||||
auto stream = context->getCudaStream();
|
||||
Nd4jLong numOfReverse = numOfElemsToReverse;
|
||||
if (numOfElemsToReverse == 0)
|
||||
numOfReverse = input->lengthOf();
|
||||
|
||||
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
||||
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
||||
}
|
||||
|
||||
|
||||
|
@ -153,27 +220,23 @@ namespace helpers {
|
|||
// we need to reverse axis only if that's new op
|
||||
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
||||
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis);
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
|
||||
auto listOut = output->allTensorsAlongDimension(dimensions);
|
||||
auto listIn = input->allTensorsAlongDimension(dimensions);
|
||||
|
||||
NDArray *subArrIn, *subArrOut;
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
|
||||
subArrIn = listIn->at(i);
|
||||
subArrOut = listOut->at(i);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES);
|
||||
|
||||
if (packX.numberOfTads() == 1) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES);
|
||||
} else {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES);
|
||||
}
|
||||
//BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast<NDArray*>(input), output, (int)0), LIBND4J_TYPES);
|
||||
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
delete listOut;
|
||||
delete listIn;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printBuffer();
|
||||
//expected.printIndexedBuffer("E");
|
||||
//result->printIndexedBuffer("R");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
|
|
@ -197,3 +197,44 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
|
|||
|
||||
delete shapes;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests16, test_reverse_1) {
|
||||
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211};
|
||||
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635};
|
||||
|
||||
for (auto r : rows) {
|
||||
for (auto c : columns) {
|
||||
//nd4j_printf("Trying [%i, %i]\n", r, c);
|
||||
auto array = NDArrayFactory::create<float>('c', {r, c});
|
||||
auto exp = NDArrayFactory::create<float>('c', {r, c});
|
||||
auto reversed = NDArrayFactory::create<float>('c', {r, c});
|
||||
|
||||
auto rowOriginal = NDArrayFactory::create<float>('c', {c});
|
||||
auto rowReversed = NDArrayFactory::create<float>('c', {c});
|
||||
|
||||
for (int e = 0; e < c; e++) {
|
||||
rowOriginal.p(e, (float) e);
|
||||
rowReversed.p(c - e - 1, (float) e);
|
||||
}
|
||||
|
||||
|
||||
auto listI = array.allTensorsAlongDimension({1});
|
||||
auto listE = exp.allTensorsAlongDimension({1});
|
||||
|
||||
for (int e = 0; e < r; e++) {
|
||||
listI->at(e)->assign(rowOriginal);
|
||||
listE->at(e)->assign(rowReversed);
|
||||
}
|
||||
|
||||
delete listI;
|
||||
delete listE;
|
||||
|
||||
nd4j::ops::reverse op;
|
||||
Nd4jLong axis = 1;
|
||||
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {});
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(exp, reversed);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@
|
|||
#include <NDArray.h>
|
||||
#include <ops/ops.h>
|
||||
#include <GradCheck.h>
|
||||
#include <chrono>
|
||||
|
||||
|
||||
using namespace nd4j;
|
||||
|
@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
|||
//ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
delete result;
|
||||
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
|
||||
auto z = x.like();
|
||||
x.linspace(1.0f);
|
||||
|
||||
nd4j::ops::reverse op;
|
||||
auto timeStart = std::chrono::system_clock::now();
|
||||
auto status = op.execute({&x}, {&z}, {}, {1}, {});
|
||||
auto timeEnd = std::chrono::system_clock::now();
|
||||
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||
nd4j_printf("exec time: %lld us\n", outerTime);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
*/
|
Loading…
Reference in New Issue