[WIP] reverse_sequence (#188)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* one more print

Signed-off-by: raver119 <raver119@gmail.com>

* minor fix

Signed-off-by: raver119 <raver119@gmail.com>

* reverse_sequence fix

Signed-off-by: raver119 <raver119@gmail.com>

* confusion_matrix test updated

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweak

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweak

Signed-off-by: raver119 <raver119@gmail.com>

* one more reverse_sequence test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-28 11:14:22 +03:00 committed by GitHub
parent 2a1431264f
commit 3157ec110c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 10 deletions

View File

@ -51,8 +51,8 @@ namespace helpers {
}
__syncthreads();
auto odd = length % 2 != 0;
auto limit = length / 2;
auto odd = numOfElemsToReverse % 2 != 0;
auto limit = numOfElemsToReverse / 2;
for (Nd4jLong e = tid; e < limit; e += step) {
// we're calculating offsets within input array
@ -102,11 +102,8 @@ namespace helpers {
seqLengths->syncToHost();
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, seqLengths});
if(input->isVector() || shape::isLikeVector(input->getShapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) {
int numOfElemsToReverse = seqLengths->e<int>(0);
// printf("Length %d\n", numOfElemsToReverse);
// input->printBuffer("INPUT");
if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim))
output->assign(input);
else
@ -122,7 +119,6 @@ namespace helpers {
auto inSubArrsSet = input->allTensorsAlongDimension(dimensions);
auto outSubArrsSet = output->allTensorsAlongDimension(dimensions);
// #pragma omp parallel for schedule(guided) if(inSubArrsSet->size() > Environment::getInstance()->elementwiseThreshold())
for(int i = 0; i < inSubArrsSet->size(); ++i) {
int numOfElemsToReverse = seqLengths->e<int>(i);
@ -143,11 +139,18 @@ namespace helpers {
delete inSubArrsSet;
delete outSubArrsSet;
}
NDArray::registerSpecialUse({output}, {input, seqLengths});
}
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
NDArray::prepareSpecialUse({output}, {input, seqLengths});
// if op isn't inplace - copy original data into output array
if (output->getSpecialBuffer() != input->getSpecialBuffer())
output->assign(input);
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input, seqLengths});
}
//////////////////////////////////////////////////////////////////////////

File diff suppressed because one or more lines are too long