[WIP] reverse_sequence (#188)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * one more print Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> * reverse_sequence fix Signed-off-by: raver119 <raver119@gmail.com> * confusion_matrix test updated Signed-off-by: raver119 <raver119@gmail.com> * minor tweak Signed-off-by: raver119 <raver119@gmail.com> * minor tweak Signed-off-by: raver119 <raver119@gmail.com> * one more reverse_sequence test Signed-off-by: raver119 <raver119@gmail.com>master
parent
2a1431264f
commit
3157ec110c
|
@ -51,8 +51,8 @@ namespace helpers {
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
auto odd = length % 2 != 0;
|
||||
auto limit = length / 2;
|
||||
auto odd = numOfElemsToReverse % 2 != 0;
|
||||
auto limit = numOfElemsToReverse / 2;
|
||||
|
||||
for (Nd4jLong e = tid; e < limit; e += step) {
|
||||
// we're calculating offsets within input array
|
||||
|
@ -102,11 +102,8 @@ namespace helpers {
|
|||
seqLengths->syncToHost();
|
||||
auto stream = context->getCudaStream();
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input, seqLengths});
|
||||
if(input->isVector() || shape::isLikeVector(input->getShapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) {
|
||||
int numOfElemsToReverse = seqLengths->e<int>(0);
|
||||
// printf("Length %d\n", numOfElemsToReverse);
|
||||
// input->printBuffer("INPUT");
|
||||
if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim))
|
||||
output->assign(input);
|
||||
else
|
||||
|
@ -122,7 +119,6 @@ namespace helpers {
|
|||
auto inSubArrsSet = input->allTensorsAlongDimension(dimensions);
|
||||
auto outSubArrsSet = output->allTensorsAlongDimension(dimensions);
|
||||
|
||||
// #pragma omp parallel for schedule(guided) if(inSubArrsSet->size() > Environment::getInstance()->elementwiseThreshold())
|
||||
for(int i = 0; i < inSubArrsSet->size(); ++i) {
|
||||
|
||||
int numOfElemsToReverse = seqLengths->e<int>(i);
|
||||
|
@ -143,11 +139,18 @@ namespace helpers {
|
|||
delete inSubArrsSet;
|
||||
delete outSubArrsSet;
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, seqLengths});
|
||||
|
||||
}
|
||||
|
||||
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
||||
NDArray::prepareSpecialUse({output}, {input, seqLengths});
|
||||
|
||||
// if op isn't inplace - copy original data into output array
|
||||
if (output->getSpecialBuffer() != input->getSpecialBuffer())
|
||||
output->assign(input);
|
||||
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
||||
NDArray::registerSpecialUse({output}, {input, seqLengths});
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue