[WIP] reverse improvements (#115)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * reverse draft Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com>master
parent
0e8a4f77bc
commit
355c6b6096
|
@ -30,6 +30,67 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) {
|
||||||
|
auto input = reinterpret_cast<T*>(vinput);
|
||||||
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
// this means that we'll have additional cycle, to move middle element
|
||||||
|
auto div = numOfElemsToReverse / 2;
|
||||||
|
auto odd = numOfElemsToReverse % 2 != 0;
|
||||||
|
auto rlimit = odd ? limit / 2 + 1 : limit / 2;
|
||||||
|
|
||||||
|
// all threads operate in the same input/output space
|
||||||
|
for (uint64_t e = tid; e < rlimit; e += step) {
|
||||||
|
// finding out the TAD we're going to process
|
||||||
|
auto tadId = e / div;
|
||||||
|
|
||||||
|
if (tadId >= numTads)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// now finding out element within tad
|
||||||
|
auto idx = e % div;
|
||||||
|
|
||||||
|
//printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx);
|
||||||
|
|
||||||
|
auto tadInput = input + inputTadOffsets[tadId];
|
||||||
|
auto tadOutput = output + outputTadOffsets[tadId];
|
||||||
|
|
||||||
|
// we're calculating offsets within input TAD
|
||||||
|
auto fOffset = shape::getIndexOffset(idx, inputTadShape);
|
||||||
|
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape);
|
||||||
|
|
||||||
|
// now we're storing input values
|
||||||
|
auto v1 = tadInput[fOffset];
|
||||||
|
auto v2 = tadInput[lOffset];
|
||||||
|
|
||||||
|
// now we're calculating offsets within output TAD
|
||||||
|
auto zfOffset = shape::getIndexOffset(idx, outputTadShape);
|
||||||
|
auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape);
|
||||||
|
|
||||||
|
// and saving values to output arrays
|
||||||
|
tadOutput[zfOffset] = v2;
|
||||||
|
tadOutput[zlOffset] = v1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// moving odd element in blocks
|
||||||
|
if (odd && threadIdx.x == 0) {
|
||||||
|
for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) {
|
||||||
|
auto tadInput = input + inputTadOffsets[e];
|
||||||
|
auto tadOutput = output + outputTadOffsets[e];
|
||||||
|
|
||||||
|
auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape);
|
||||||
|
auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape);
|
||||||
|
|
||||||
|
tadOutput[zOffset] = tadInput[xOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
|
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -52,7 +113,7 @@ namespace helpers {
|
||||||
auto odd = numOfElemsToReverse % 2 != 0;
|
auto odd = numOfElemsToReverse % 2 != 0;
|
||||||
auto limit = numOfElemsToReverse / 2;
|
auto limit = numOfElemsToReverse / 2;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < limit; e += step) {
|
for (uint64_t e = tid; e < limit; e += step) {
|
||||||
// we're calculating offsets within input array
|
// we're calculating offsets within input array
|
||||||
auto fOffset = shape::getIndexOffset(e, inputShape);
|
auto fOffset = shape::getIndexOffset(e, inputShape);
|
||||||
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
|
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
|
||||||
|
@ -80,13 +141,19 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
reverseTadKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numOfReverse = numOfElemsToReverse;
|
Nd4jLong numOfReverse = numOfElemsToReverse;
|
||||||
if (numOfElemsToReverse == 0)
|
if (numOfElemsToReverse == 0)
|
||||||
numOfReverse = input->lengthOf();
|
numOfReverse = input->lengthOf();
|
||||||
|
|
||||||
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,27 +220,23 @@ namespace helpers {
|
||||||
// we need to reverse axis only if that's new op
|
// we need to reverse axis only if that's new op
|
||||||
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
||||||
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis);
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis);
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
|
||||||
auto listOut = output->allTensorsAlongDimension(dimensions);
|
|
||||||
auto listIn = input->allTensorsAlongDimension(dimensions);
|
|
||||||
|
|
||||||
NDArray *subArrIn, *subArrOut;
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
|
|
||||||
subArrIn = listIn->at(i);
|
if (packX.numberOfTads() == 1) {
|
||||||
subArrOut = listOut->at(i);
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES);
|
} else {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
//BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast<NDArray*>(input), output, (int)0), LIBND4J_TYPES);
|
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
delete listOut;
|
|
||||||
delete listIn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
auto result = results->at(0);
|
auto result = results->at(0);
|
||||||
// result->printBuffer();
|
//expected.printIndexedBuffer("E");
|
||||||
|
//result->printIndexedBuffer("R");
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
|
@ -197,3 +197,44 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
|
||||||
|
|
||||||
delete shapes;
|
delete shapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests16, test_reverse_1) {
|
||||||
|
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211};
|
||||||
|
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635};
|
||||||
|
|
||||||
|
for (auto r : rows) {
|
||||||
|
for (auto c : columns) {
|
||||||
|
//nd4j_printf("Trying [%i, %i]\n", r, c);
|
||||||
|
auto array = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
auto reversed = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
|
||||||
|
auto rowOriginal = NDArrayFactory::create<float>('c', {c});
|
||||||
|
auto rowReversed = NDArrayFactory::create<float>('c', {c});
|
||||||
|
|
||||||
|
for (int e = 0; e < c; e++) {
|
||||||
|
rowOriginal.p(e, (float) e);
|
||||||
|
rowReversed.p(c - e - 1, (float) e);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
auto listI = array.allTensorsAlongDimension({1});
|
||||||
|
auto listE = exp.allTensorsAlongDimension({1});
|
||||||
|
|
||||||
|
for (int e = 0; e < r; e++) {
|
||||||
|
listI->at(e)->assign(rowOriginal);
|
||||||
|
listE->at(e)->assign(rowReversed);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete listI;
|
||||||
|
delete listE;
|
||||||
|
|
||||||
|
nd4j::ops::reverse op;
|
||||||
|
Nd4jLong axis = 1;
|
||||||
|
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, reversed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,6 +24,7 @@
|
||||||
#include <NDArray.h>
|
#include <NDArray.h>
|
||||||
#include <ops/ops.h>
|
#include <ops/ops.h>
|
||||||
#include <GradCheck.h>
|
#include <GradCheck.h>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
|
||||||
|
auto z = x.like();
|
||||||
|
x.linspace(1.0f);
|
||||||
|
|
||||||
|
nd4j::ops::reverse op;
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {1}, {});
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
nd4j_printf("exec time: %lld us\n", outerTime);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
*/
|
Loading…
Reference in New Issue