draw_bounding_boxes op implementation. Inital revision.

master
shugeo 2019-10-04 18:32:21 +03:00
parent b8f2a83a5a
commit 8f70b4441f
3 changed files with 97 additions and 0 deletions

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_draw_bounding_boxes)
//#include <ops/declarable/headers/parity_ops.h>
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
OP_IMPL(draw_bounding_boxes, 3, 1, true) {
auto image = INPUT_VARIABLE(0);
auto boxes = INPUT_VARIABLE(1);
auto colors = INPUT_VARIABLE(2);
return ND4J_STATUS_OK;
}
DECLARE_TYPES(draw_bounding_boxes) {
getOpDescriptor()
->setAllowedInputTypes(0, {HALF, FLOAT32})// TF allows HALF and FLOAT32 only
->setAllowedInputTypes(1, {FLOAT32}) // as TF
->setAllowedInputTypes(2, {FLOAT32}) // as TF
->setAllowedOutputTypes({HALF, FLOAT32}); // TF allows HALF and FLOAT32 only
}
}
}
#endif

View File

@ -1244,6 +1244,23 @@ namespace nd4j {
DECLARE_CUSTOM_OP(extract_image_patches, 1, 1, false, 0, 7);
#endif
/**
* draw_bounding_boxes op - modified input image with given colors exept given boxes.
*
* input params:
* 0 - images tensor (4D) with shape {batch, width, height, channels}, where channes is 1 (BW image),
* 3 (RGB) or 4 (RGBA)
* 1 - boxes tensor (3D) with shape {batch, number_of_boxes, 4} where last dimension encoded as
* (y_min, x_min, y_max, x_max), all values in between 0. and 1.
* 2 - colours tensor (2D) with shape {number_of_boxes, channels} -- bordering color set (palette)
*
* output:
* 0 - 4D tensor with same shape as images (input 0)
*/
#if NOT_EXCLUDED(OP_draw_bounding_boxes)
DECLARE_OP(draw_bounding_boxes, 3, 1, true);
#endif
/**
* roll - op porting from numpy (https://docs.scipy.org/doc/numpy-1.14.0/reference/generated/numpy.roll.html)
*

View File

@ -2043,6 +2043,39 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
int axis = 0;
NDArray images = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray boxes = NDArrayFactory::create<float>('c', {2, 2, 4}, {0,0,1,1});
NDArray colors = NDArrayFactory::create<float>('c', {2, 3}, {201., 202., 203., 128., 129., 130.});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected = NDArrayFactory::create<float>('c', {2,4,5,3}, {
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
127., 128., 129., 19., 20., 21., 22., 23., 24., 127., 128., 129., 201., 202., 203.,
127., 128., 129., 127., 128., 129., 127., 128., 129., 127., 128., 129., 201., 202., 203.,
201., 202., 203., 201. ,202. ,203., 201., 202., 203., 201., 202., 203., 201., 202., 203.,
61., 62., 63., 201., 202., 203., 201., 202., 203., 70., 71., 72., 73., 74., 75.,
76., 77., 78., 127., 128., 129., 127., 128., 129., 85., 86., 87., 88., 89., 90.,
91., 92., 93., 201., 202., 203., 201., 202., 203., 100., 101., 102., 103., 104., 105.,
106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.
});
nd4j::ops::draw_bounding_boxes op;
auto results = op.execute({&images, &boxes, &colors}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
result->printIndexedBuffer("Bounded boxes");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {