[WIP] reverse_sequence (#188)
* initial commit
  Signed-off-by: raver119 <raver119@gmail.com>
* one more print
  Signed-off-by: raver119 <raver119@gmail.com>
* minor fix
  Signed-off-by: raver119 <raver119@gmail.com>
* reverse_sequence fix
  Signed-off-by: raver119 <raver119@gmail.com>
* confusion_matrix test updated
  Signed-off-by: raver119 <raver119@gmail.com>
* minor tweak
  Signed-off-by: raver119 <raver119@gmail.com>
* minor tweak
  Signed-off-by: raver119 <raver119@gmail.com>
* one more reverse_sequence test
  Signed-off-by: raver119 <raver119@gmail.com>
master
parent
2a1431264f
commit
3157ec110c
|
@ -51,8 +51,8 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
auto odd = length % 2 != 0;
|
auto odd = numOfElemsToReverse % 2 != 0;
|
||||||
auto limit = length / 2;
|
auto limit = numOfElemsToReverse / 2;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < limit; e += step) {
|
for (Nd4jLong e = tid; e < limit; e += step) {
|
||||||
// we're calculating offsets within input array
|
// we're calculating offsets within input array
|
||||||
|
@ -102,11 +102,8 @@ namespace helpers {
|
||||||
seqLengths->syncToHost();
|
seqLengths->syncToHost();
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({output}, {input, seqLengths});
|
|
||||||
if(input->isVector() || shape::isLikeVector(input->getShapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) {
|
if(input->isVector() || shape::isLikeVector(input->getShapeInfo(), posOfNonUnityDim) || seqLengths->lengthOf() == 1) {
|
||||||
int numOfElemsToReverse = seqLengths->e<int>(0);
|
int numOfElemsToReverse = seqLengths->e<int>(0);
|
||||||
// printf("Length %d\n", numOfElemsToReverse);
|
|
||||||
// input->printBuffer("INPUT");
|
|
||||||
if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim))
|
if((seqDim == 0 && input->sizeAt(0) == 1) || (batchDim == posOfNonUnityDim))
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
else
|
else
|
||||||
|
@ -122,7 +119,6 @@ namespace helpers {
|
||||||
auto inSubArrsSet = input->allTensorsAlongDimension(dimensions);
|
auto inSubArrsSet = input->allTensorsAlongDimension(dimensions);
|
||||||
auto outSubArrsSet = output->allTensorsAlongDimension(dimensions);
|
auto outSubArrsSet = output->allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
// #pragma omp parallel for schedule(guided) if(inSubArrsSet->size() > Environment::getInstance()->elementwiseThreshold())
|
|
||||||
for(int i = 0; i < inSubArrsSet->size(); ++i) {
|
for(int i = 0; i < inSubArrsSet->size(); ++i) {
|
||||||
|
|
||||||
int numOfElemsToReverse = seqLengths->e<int>(i);
|
int numOfElemsToReverse = seqLengths->e<int>(i);
|
||||||
|
@ -143,11 +139,18 @@ namespace helpers {
|
||||||
delete inSubArrsSet;
|
delete inSubArrsSet;
|
||||||
delete outSubArrsSet;
|
delete outSubArrsSet;
|
||||||
}
|
}
|
||||||
NDArray::registerSpecialUse({output}, {input, seqLengths});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, seqLengths});
|
||||||
|
|
||||||
|
// if op isn't inplace - copy original data into output array
|
||||||
|
if (output->getSpecialBuffer() != input->getSpecialBuffer())
|
||||||
|
output->assign(input);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, seqLengths});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue