The working implementation of draw_bounding_boxes op.

master
shugeo 2019-10-08 15:42:27 +03:00
parent 30a8af566c
commit 8fe5a1fa96
2 changed files with 12 additions and 11 deletions

View File

@@ -40,6 +40,8 @@ namespace helpers {
// auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch
auto colorSet = colors->allTensorsAlongDimension({1}); auto colorSet = colors->allTensorsAlongDimension({1});
output->assign(images); // fill up all output with input images, then fill up boxes output->assign(images); // fill up all output with input images, then fill up boxes
PRAGMA_OMP_PARALLEL_FOR
for (auto b = 0; b < batchSize; ++b) { // loop by batch for (auto b = 0; b < batchSize; ++b) { // loop by batch
// auto image = imageList->at(b); // auto image = imageList->at(b);

View File

@@ -29,32 +29,31 @@ namespace helpers {
Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape, Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape,
Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) { Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) {
for (auto b = blockIdx.x; b < batchSize; b += gridDim.x) { // loop by batch for (auto b = blockIdx.x; b < (int)batchSize; b += gridDim.x) { // loop by batch
for (auto c = threadIdx.x; c < colorSetSize; c += blockDim.x) { for (auto c = 0; c < colorSetSize; c++) {
// box with shape // box with shape
auto pos = channels * c;
auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c);
auto color = &colors[pos];//colorSet->at(c); auto color = &colors[channels * c];//colorSet->at(c);
auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0]));
auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2]));
auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1]));
auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3]));
for (auto y = rowStart; y <= rowEnd; y++) { for (auto y = rowStart + threadIdx.x; y <= rowEnd; y += blockDim.x) {
for (auto e = 0; e < channels; ++e) { for (auto e = 0; e < channels; ++e) {
Nd4jLong yMinPos[] = {b, y, colStart, e}; Nd4jLong yMinPos[] = {b, y, colStart, e};
Nd4jLong yMaxPos[] = {b, y, colEnd, e}; Nd4jLong yMaxPos[] = {b, y, colEnd, e};
auto zIndexYmin = shape::getOffset(outputShape, yMinPos, 0); auto zIndexYmin = shape::getOffset(outputShape, yMinPos);
auto zIndexYmax = shape::getOffset(outputShape, yMaxPos, 0); auto zIndexYmax = shape::getOffset(outputShape, yMaxPos);
output[zIndexYmin] = color[e]; output[zIndexYmin] = color[e];
output[zIndexYmax] = color[e]; output[zIndexYmax] = color[e];
} }
} }
for (auto x = colStart + 1; x < colEnd; x++) { for (auto x = colStart + 1 + threadIdx.x; x < colEnd; x += blockDim.x) {
for (auto e = 0; e < channels; ++e) { for (auto e = 0; e < channels; ++e) {
Nd4jLong xMinPos[] = {b, rowStart, x, e}; Nd4jLong xMinPos[] = {b, rowStart, x, e};
Nd4jLong xMaxPos[] = {b, rowEnd, x, e}; Nd4jLong xMaxPos[] = {b, rowEnd, x, e};
auto zIndexXmin = shape::getOffset(outputShape, xMinPos, 0); auto zIndexXmin = shape::getOffset(outputShape, xMinPos);
auto zIndexXmax = shape::getOffset(outputShape, xMaxPos, 0); auto zIndexXmax = shape::getOffset(outputShape, xMaxPos);
output[zIndexXmin] = color[e]; output[zIndexXmin] = color[e];
output[zIndexXmax] = color[e]; output[zIndexXmax] = color[e];
} }
@@ -77,7 +76,7 @@ namespace helpers {
auto boxesBuf = boxes->getDataBuffer()->specialAsT<T>(); auto boxesBuf = boxes->getDataBuffer()->specialAsT<T>();
auto colorsBuf = colors->getDataBuffer()->specialAsT<T>(); auto colorsBuf = colors->getDataBuffer()->specialAsT<T>();
auto outputBuf = output->dataBuffer()->specialAsT<T>(); auto outputBuf = output->dataBuffer()->specialAsT<T>();
drawBoundingBoxesKernel<<<1, 1, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), drawBoundingBoxesKernel<<<batchSize > 128? 128: batchSize, 256, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(),
boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(),
outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize); outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize);
} }