Shugeo release fix1 (#61)

* Added a pair of tests for failed ops.

* Fixed cpu helper for draw_bounding_boxes op.

* Refactored implementation of draw_bounding_boxes op to fully conform with TF.

* Improved multithreading with draw_bounding_boxes op cuda helper.

* Eliminated log messages.

* Changed logging with draw_bounding_boxes op helper and tests.

* Allowed 3D input for resize_bilinear.

* Refactored 3D input acceptance in resize_bilinear op.

* And another improvement.

* Refactored reshape of input/output for resize_bilinear.

* Improvements final.

* Finished with 3D replication for image.resize_bilinear/_nearest_neighbor.

* Added copyrights for TF code.

* Using new form of multithreading for cpu implementation.

* Fixed shape error.

* Added multithreading over batches to the crop_and_resize functor.

* Refactored multithreading with crop_and_resize and draw_bounding_boxes.
master
shugeo 2019-11-20 13:37:48 +02:00 committed by GitHub
parent 59e955cedc
commit 13e5c0a280
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 565 additions and 106 deletions

View File

@ -57,6 +57,7 @@ namespace nd4j {
DECLARE_SHAPE_FN(crop_and_resize) { DECLARE_SHAPE_FN(crop_and_resize) {
auto in = inputShape->at(0); auto in = inputShape->at(0);
auto boxShape = inputShape->at(1);
Nd4jLong outputShape[4]; Nd4jLong outputShape[4];
@ -68,7 +69,7 @@ namespace nd4j {
width = newImageSize->e<int>(0); width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1); height = newImageSize->e<int>(1);
outputShape[0] = in[1]; outputShape[0] = boxShape[1];
outputShape[1] = width; outputShape[1] = width;
outputShape[2] = height; outputShape[2] = height;
outputShape[3] = in[4]; outputShape[3] = in[4];

View File

@ -29,9 +29,31 @@ namespace nd4j {
auto images = INPUT_VARIABLE(0); auto images = INPUT_VARIABLE(0);
auto boxes = INPUT_VARIABLE(1); auto boxes = INPUT_VARIABLE(1);
auto colors = INPUT_VARIABLE(2);
auto output = OUTPUT_VARIABLE(0);
auto colors = (NDArray*) nullptr;
if (block.width() > 2) // TF v.1.x ommits color set for boxes, and use color 1.0 for fill up
colors = INPUT_VARIABLE(2); // but v.2.y require color set
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(images->dataType() == output->dataType(), 0, "draw_bounding_boxes: Input and Output types "
"should be equals, but %d and %d occured.",
(int)images->dataType(), (int)output->dataType());
REQUIRE_TRUE(images->rankOf() == 4, 0, "draw_bounding_boxes: Images input should be 4D tensor, but %i occured.",
images->rankOf());
REQUIRE_TRUE(boxes->rankOf() == 3, 0, "draw_bounding_boxes: Boxes should be 3D tensor, but %i occured.",
boxes->rankOf());
if (colors) {
REQUIRE_TRUE(colors->rankOf() == 2, 0, "draw_bounding_boxes: Color set should be 2D matrix, but %i occured.",
colors->rankOf());
REQUIRE_TRUE(colors->sizeAt(1) >= images->sizeAt(3), 0, "draw_bounding_boxes: Color set last dim "
"should be not less than images depth, but "
"%lld and %lld occured.",
colors->sizeAt(1), images->sizeAt(3));
}
REQUIRE_TRUE(boxes->sizeAt(0) == images->sizeAt(0), 0, "draw_bounding_boxes: Batches for images and boxes "
"should be the same, but %lld and %lld occured.",
images->sizeAt(0), boxes->sizeAt(0));
helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output); helpers::drawBoundingBoxesFunctor(block.launchContext(), images, boxes, colors, output);
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }

View File

@ -33,6 +33,15 @@ namespace nd4j {
int width; int width;
int height; int height;
bool center = false; // - default value bool center = false; // - default value
auto inRank = image->rankOf();
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
"tensor, but input has rank %i",
image->rankOf());
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
if (block.width() > 1) { if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1); auto newImageSize = INPUT_VARIABLE(1);
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
@ -51,7 +60,8 @@ namespace nd4j {
center = 0 != INT_ARG(2); center = 0 != INT_ARG(2);
} }
return helpers::resizeBilinearFunctor(block.launchContext(), image, width, height, center, output); auto res = helpers::resizeBilinearFunctor(block.launchContext(), &source, width, height, center, &target);
return res;
} }
DECLARE_SHAPE_FN(resize_bilinear) { DECLARE_SHAPE_FN(resize_bilinear) {
@ -59,6 +69,10 @@ namespace nd4j {
auto in = inputShape->at(0); auto in = inputShape->at(0);
Nd4jLong* outputShape; Nd4jLong* outputShape;
auto inRank = shape::rank(in);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
"tensor, but input has rank %i",
inRank);
int width; int width;
int height; int height;
@ -75,12 +89,19 @@ namespace nd4j {
height = INT_ARG(1); height = INT_ARG(1);
} }
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = 4; outputShape[0] = inRank;
outputShape[1] = in[1]; if (inRank == 4) {
outputShape[2] = width; outputShape[1] = in[1];
outputShape[3] = height; outputShape[2] = width;
outputShape[4] = in[4]; outputShape[3] = height;
outputShape[4] = in[4];
}
else { // input shape is 3D, so result also should be 3D
outputShape[1] = width;
outputShape[2] = height;
outputShape[3] = in[3];
}
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
shapeList->push_back(CONSTANT(outputShape)); shapeList->push_back(CONSTANT(outputShape));

View File

@ -51,16 +51,25 @@ namespace nd4j {
if (block.numI() == 3) if (block.numI() == 3)
center = 0 != INT_ARG(2); center = 0 != INT_ARG(2);
} }
auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeNeighborFunctor(block.launchContext(), image, width, height, center, output); return helpers::resizeNeighborFunctor(block.launchContext(), &source, width, height, center, &target);
} }
DECLARE_SHAPE_FN(resize_nearest_neighbor) { DECLARE_SHAPE_FN(resize_nearest_neighbor) {
auto shapeList = SHAPELIST(); auto shapeList = SHAPELIST();
auto in = inputShape->at(0); auto in = inputShape->at(0);
auto inRank = shape::rank(in);
Nd4jLong* outputShape; Nd4jLong* outputShape;
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
"tensor, but input has rank %i",
inRank);
int width; int width;
int height; int height;
if (block.width() > 1) { if (block.width() > 1) {
@ -76,12 +85,19 @@ namespace nd4j {
height = INT_ARG(1); height = INT_ARG(1);
} }
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(4), Nd4jLong); ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = 4; outputShape[0] = inRank;
outputShape[1] = in[1]; if (inRank == 4) {
outputShape[2] = width; outputShape[1] = in[1];
outputShape[3] = height; outputShape[2] = width;
outputShape[4] = in[4]; outputShape[3] = height;
outputShape[4] = in[4];
}
else { // input shape is 3D, so result also should be 3D
outputShape[1] = width;
outputShape[2] = height;
outputShape[3] = in[3];
}
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
shapeList->push_back(CONSTANT(outputShape)); shapeList->push_back(CONSTANT(outputShape));

View File

@ -47,7 +47,7 @@ namespace helpers {
if (zeroPointFromMin > quantMaxF) { if (zeroPointFromMin > quantMaxF) {
return static_cast<uint16_t>(quantMax); return static_cast<uint16_t>(quantMax);
} }
return nd4j::math::nd4j_round<T,uint16_t>(zeroPointFromMin); return (uint16_t)nd4j::math::nd4j_round<T,int>(zeroPointFromMin);
}(); }();
// compute nudged min and max with computed nudged zero point // compute nudged min and max with computed nudged zero point
*nudgedMin = (quantMinF - nudged_zero_point) * (*scale); *nudgedMin = (quantMinF - nudged_zero_point) * (*scale);

View File

@ -13,16 +13,52 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// //
// @author sgazeos@gmail.com // @author sgazeos@gmail.com
// //
#include <op_boilerplate.h> #include <op_boilerplate.h>
#include <NDArray.h> #include <NDArray.h>
#include <execution/Threads.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
// A color table is a list of colors; each color holds four float components.
typedef std::vector<std::vector<float>> ColorTable_t;

/**
 * Builds the default palette used to draw box frames when no explicit color set is given.
 *
 * The table always holds ten colors of four components each.
 *
 * @param depth number of image channels; when it equals 1 the first component of every
 *              color is forced to 1, so a single-channel image gets the same value for
 *              every box regardless of the palette entry
 * @return a freshly built table of ten four-component colors
 */
static ColorTable_t DefaultColorTable(int depth) {
    // Built in one list-initialization instead of ten temporary vectors pushed one by one.
    ColorTable_t colorTable = {
        {1.f,  1.f,  0.f,  1.f},  // 0: yellow
        {0.f,  0.f,  1.f,  1.f},  // 1: blue
        {1.f,  0.f,  0.f,  1.f},  // 2: red
        {0.f,  1.f,  0.f,  1.f},  // 3: lime
        {0.5f, 0.f,  0.5f, 1.f},  // 4: purple
        {0.5f, 0.5f, 0.f,  1.f},  // 5: olive
        {0.5f, 0.f,  0.f,  1.f},  // 6: maroon
        {0.f,  0.f,  0.5f, 1.f},  // 7: navy blue
        {0.f,  1.f,  1.f,  1.f},  // 8: aqua
        {1.f,  0.f,  1.f,  1.f}   // 9: fuchsia
    };

    if (depth == 1) {
        // Range-for avoids comparing a signed index against the unsigned size().
        for (auto& color : colorTable) {
            color[0] = 1.f;
        }
    }
    return colorTable;
}
void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {
// images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set // images - batch of 3D images with BW (last dim = 1), RGB (last dim = 3) or RGBA (last dim = 4) channel set
@ -33,44 +69,107 @@ namespace helpers {
// colors - colors for each box given // colors - colors for each box given
// set up color for each box as frame // set up color for each box as frame
auto batchSize = images->sizeAt(0); auto batchSize = images->sizeAt(0);
auto boxSize = boxes->sizeAt(0);
auto height = images->sizeAt(1); auto height = images->sizeAt(1);
auto width = images->sizeAt(2); auto width = images->sizeAt(2);
auto channels = images->sizeAt(3); auto channels = images->sizeAt(3);
//auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch //auto imageList = images->allTensorsAlongDimension({1, 2, 3}); // split images by batch
// auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch // auto boxList = boxes->allTensorsAlongDimension({1, 2}); // split boxes by batch
auto colorSet = colors->allTensorsAlongDimension({1}); //auto colorSet = colors->allTensorsAlongDimension({0});
output->assign(images); // fill up all output with input images, then fill up boxes output->assign(images); // fill up all output with input images, then fill up boxes
ColorTable_t colorTable;
PRAGMA_OMP_PARALLEL_FOR if (colors) {
for (auto b = 0; b < batchSize; ++b) { // loop by batch for (auto i = 0; i < colors->sizeAt(0); i++) {
// auto image = imageList->at(b); std::vector<float> colorValue(4);
for (auto j = 0; j < 4; j++) {
for (auto c = 0; c < colorSet->size(); ++c) { colorValue[j] = j < colors->sizeAt(1) ? colors->e<float>(i, j) : 1.f;
// box with shape
auto internalBox = (*boxes)(b, {0})(c, {0});//internalBoxes->at(c);
auto color = colorSet->at(c);
auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox.e<float>(0)));
auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox.e<float>(2)));
auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox.e<float>(1)));
auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox.e<float>(3)));
for (auto y = rowStart; y <= rowEnd; y++) {
for (auto e = 0; e < color->lengthOf(); ++e) {
output->p(b, y, colStart, e, color->e(e));
output->p(b, y, colEnd, e, color->e(e));
}
}
for (auto x = colStart + 1; x < colEnd; x++) {
for (auto e = 0; e < color->lengthOf(); ++e) {
output->p(b, rowStart, x, e, color->e(e));
output->p(b, rowEnd, x, e, color->e(e));
}
} }
colorTable.emplace_back(colorValue);
} }
// delete internalBoxes;
} }
delete colorSet; if (colorTable.empty())
// delete imageList; colorTable = DefaultColorTable(channels);
// delete boxList; // PRAGMA_OMP_PARALLEL_FOR
auto func = PRAGMA_THREADS_FOR {
for (auto batch = 0; batch < batchSize; ++batch) { // loop by batch
// auto image = imageList->at(batch);
const Nd4jLong numBoxes = boxes->sizeAt(1);
for (auto boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
// box with shape
//auto internalBox = (*boxes)(batch, {0})(boxIndex, {0});//internalBoxes->at(c);
//auto color = colorSet->at(c);
//internalBox.printIndexedBuffer("Current Box");
auto colorIndex = boxIndex % colorTable.size();
auto rowStart = Nd4jLong((height - 1) * boxes->t<float>(batch, boxIndex, 0));
auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong(0), rowStart);
auto rowEnd = Nd4jLong((height - 1) * boxes->t<float>(batch, boxIndex, 2));
auto rowEndBound = nd4j::math::nd4j_min(Nd4jLong(height - 1), rowEnd);
auto colStart = Nd4jLong((width - 1) * boxes->t<float>(batch, boxIndex, 1));
auto colStartBound = nd4j::math::nd4j_max(Nd4jLong(0), colStart);
auto colEnd = Nd4jLong((width - 1) * boxes->t<float>(batch, boxIndex, 3));
auto colEndBound = nd4j::math::nd4j_min(Nd4jLong(width - 1), colEnd);
if (rowStart > rowEnd || colStart > colEnd) {
nd4j_debug(
"helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted "
"and will not be drawn\n", rowStart, colStart, rowEnd, colEnd);
continue;
}
if (rowStart >= height || rowEnd < 0 || colStart >= width ||
colEnd < 0) {
nd4j_debug(
"helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely "
"outside the image and not be drawn\n ", rowStart, colStart, rowEnd, colEnd);
continue;
}
// Draw upper line
if (rowStart >= 0) {
for (auto j = colStartBound; j <= colEndBound; ++j)
for (auto c = 0; c < channels; c++) {
output->p(batch, rowStart, j, c, colorTable[colorIndex][c]);
}
}
// Draw bottom line.
if (rowEnd < height) {
for (auto j = colStartBound; j <= colEndBound; ++j)
for (auto c = 0; c < channels; c++) {
output->p(batch, rowEnd, j, c, colorTable[colorIndex][c]);
}
}
// Draw left line.
if (colStart >= 0) {
for (auto i = rowStartBound; i <= rowEndBound; ++i)
for (auto c = 0; c < channels; c++) {
output->p(batch, i, colStart, c, colorTable[colorIndex][c]);
}
}
// Draw right line.
if (colEnd < width) {
for (auto i = rowStartBound; i <= rowEndBound; ++i)
for (auto c = 0; c < channels; c++) {
output->p(batch, i, colEnd, c, colorTable[colorIndex][c]);
}
}
// for (auto y = rowStart; y <= rowEnd; y++) {
// for (auto e = 0; e < channels; ++e) {
// output->p(batch, y, colStart, e, colorTable[colorIndex][e]);
// output->p(batch, y, colEnd, e, colorTable[colorIndex][e]);
// }
// }
// for (auto x = colStart + 1; x < colEnd; x++) {
// for (auto e = 0; e < channels; ++e) {
// output->p(batch, rowStart, x, e, colorTable[colorIndex][e]);
// output->p(batch, rowEnd, x, e, colorTable[colorIndex][e]);
// }
// }
}
// delete internalBoxes;
}
};
samediff::Threads::parallel_tad(func, 0, batchSize);
} }
} }

View File

@ -14,6 +14,20 @@
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// //
// @author sgazeos@gmail.com // @author sgazeos@gmail.com
@ -85,30 +99,32 @@ namespace helpers {
return top + (bottom - top) * yVal; return top + (bottom - top) * yVal;
}; };
// FIXME: fix parallelism here auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong b = 0; b < batchSize; ++b) { for (Nd4jLong b = 0; b < batchSize; ++b) {
for (Nd4jLong y = 0; y < outHeight; ++y) { for (Nd4jLong y = 0; y < outHeight; ++y) {
const T *ys_input_lower_ptr = input_b_ptr + ys[y].bottomIndex * inRowSize; const T *ys_input_lower_ptr = input_b_ptr + ys[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = input_b_ptr + ys[y].topIndex * inRowSize; const T *ys_input_upper_ptr = input_b_ptr + ys[y].topIndex * inRowSize;
double yVal = ys[y].interpolarValue; double yVal = ys[y].interpolarValue;
for (Nd4jLong x = 0; x < outWidth; ++x) { for (Nd4jLong x = 0; x < outWidth; ++x) {
auto xsBottom = xs_[x].bottomIndex; auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex; auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue; auto xVal = xs_[x].interpolarValue;
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
double topLeft(ys_input_lower_ptr[xsBottom + c]); double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]); double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]); double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
double bottomRight(ys_input_upper_ptr[xsTop + c]); double bottomRight(ys_input_upper_ptr[xsTop + c]);
output_y_ptr[x * channels + c] = output_y_ptr[x * channels + c] =
computeBilinear(topLeft, topRight, bottomLeft, bottomRight, computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
xVal, yVal); xVal, yVal);
}
} }
output_y_ptr += outRowSize;
} }
output_y_ptr += outRowSize; input_b_ptr += inBatchNumValues;
} }
input_b_ptr += inBatchNumValues; };
} samediff::Threads::parallel_tad(func, 0, batchSize);
} }
template<typename T> template<typename T>

View File

@ -24,39 +24,112 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> typedef NDArray ColorTable_t;
static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, T const* boxes, static NDArray DefaultColorTable(int depth) {
Nd4jLong* boxesShape, T const* colors, Nd4jLong* colorsShape, T* output, Nd4jLong* outputShape, //std::vector<std::vector<float>> colorTable;
Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong colorSetSize) { const Nd4jLong kDefaultTableLength = 10;
const Nd4jLong kDefaultChannelLength = 4;
NDArray colorTable('c', {kDefaultTableLength, kDefaultChannelLength}, {
1,1,0,1, // yellow
0, 0, 1, 1, // 1: blue
1, 0, 0, 1, // 2: red
0, 1, 0, 1, // 3: lime
0.5, 0, 0.5, 1, // 4: purple
0.5, 0.5, 0, 1, // 5: olive
0.5, 0, 0, 1, // 6: maroon
0, 0, 0.5, 1, // 7: navy blue
0, 1, 1, 1, // 8: aqua
1, 0, 1, 1 // 9: fuchsia
}, DataType::FLOAT32);
for (auto b = blockIdx.x; b < (int)batchSize; b += gridDim.x) { // loop by batch if (depth == 1) {
for (auto c = 0; c < colorSetSize; c++) { colorTable.assign(1.f); // all to white when black and white colors
}
return colorTable;
}
template <typename T>
static __global__ void drawBoundingBoxesKernel(T const* images, Nd4jLong* imagesShape, float const* boxes,
Nd4jLong* boxesShape, float const* colorTable, Nd4jLong* colorTableShape, T* output, Nd4jLong* outputShape,
Nd4jLong batchSize, Nd4jLong width, Nd4jLong height, Nd4jLong channels, Nd4jLong boxSize, Nd4jLong colorTableLen) {
for (auto batch = blockIdx.x; batch < (int)batchSize; batch += gridDim.x) { // loop by batch
for (auto boxIndex = 0; boxIndex < boxSize; ++boxIndex) {
// box with shape // box with shape
auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c); //auto internalBox = &boxes[b * colorSetSize * 4 + c * 4];//(*boxes)(b, {0})(c, {0});//internalBoxes->at(c);
auto color = &colors[channels * c];//colorSet->at(c); auto colorIndex = boxIndex % colorTableLen;//colorSet->at(c);
auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0])); // auto rowStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((height - 1) * internalBox[0]));
auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2])); // auto rowEnd = nd4j::math::nd4j_min(Nd4jLong (height - 1), Nd4jLong ((height - 1) * internalBox[2]));
auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1])); // auto colStart = nd4j::math::nd4j_max(Nd4jLong (0), Nd4jLong ((width - 1) * internalBox[1]));
auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3])); // auto colEnd = nd4j::math::nd4j_min(Nd4jLong(width - 1), Nd4jLong ((width - 1) * internalBox[3]));
for (auto y = rowStart + threadIdx.x; y <= rowEnd; y += blockDim.x) { Nd4jLong indices0[] = {batch, boxIndex, 0};
for (auto e = 0; e < channels; ++e) { Nd4jLong indices1[] = {batch, boxIndex, 1};
Nd4jLong yMinPos[] = {b, y, colStart, e}; Nd4jLong indices2[] = {batch, boxIndex, 2};
Nd4jLong yMaxPos[] = {b, y, colEnd, e}; Nd4jLong indices3[] = {batch, boxIndex, 3};
auto zIndexYmin = shape::getOffset(outputShape, yMinPos); auto rowStart = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices0, 0)]);
auto zIndexYmax = shape::getOffset(outputShape, yMaxPos); auto rowStartBound = nd4j::math::nd4j_max(Nd4jLong (0), rowStart);
output[zIndexYmin] = color[e]; auto rowEnd = Nd4jLong ((height - 1) * boxes[shape::getOffset(boxesShape, indices2, 0)]);
output[zIndexYmax] = color[e]; auto rowEndBound = nd4j::math::nd4j_min(Nd4jLong (height - 1), rowEnd);
} auto colStart = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices1, 0)]);
auto colStartBound = nd4j::math::nd4j_max(Nd4jLong (0), colStart);
auto colEnd = Nd4jLong ((width - 1) * boxes[shape::getOffset(boxesShape, indices3, 0)]);
auto colEndBound = nd4j::math::nd4j_min(Nd4jLong(width - 1), colEnd);
if (rowStart > rowEnd || colStart > colEnd) {
// printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is inverted "
// "and will not be drawn\n", rowStart, colStart, rowEnd, colEnd);
continue;
} }
for (auto x = colStart + 1 + threadIdx.x; x < colEnd; x += blockDim.x) { if (rowStart >= height || rowEnd < 0 || colStart >= width ||
for (auto e = 0; e < channels; ++e) { colEnd < 0) {
Nd4jLong xMinPos[] = {b, rowStart, x, e}; // printf("helpers::drawBoundingBoxesFunctor: Bounding box (%lld, %lld, %lld, %lld) is completely "
Nd4jLong xMaxPos[] = {b, rowEnd, x, e}; // "outside the image and not be drawn\n", rowStart, colStart, rowEnd, colEnd);
auto zIndexXmin = shape::getOffset(outputShape, xMinPos); continue;
auto zIndexXmax = shape::getOffset(outputShape, xMaxPos); }
output[zIndexXmin] = color[e];
output[zIndexXmax] = color[e]; // Draw upper line
} if (rowStart >= 0) {
for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x)
for (auto c = 0; c < channels; c++) {
Nd4jLong zPos[] = {batch, rowStart, j, c};
Nd4jLong cPos[] = {colorIndex, c};
auto cIndex = shape::getOffset(colorTableShape, cPos, 0);
auto zIndex = shape::getOffset(outputShape, zPos, 0);
output[zIndex] = (T)colorTable[cIndex];
}
}
// Draw bottom line.
if (rowEnd < height) {
for (auto j = colStartBound + threadIdx.x; j <= colEndBound; j += blockDim.x)
for (auto c = 0; c < channels; c++) {
Nd4jLong zPos[] = {batch, rowEnd, j, c};
Nd4jLong cPos[] = {colorIndex, c};
auto cIndex = shape::getOffset(colorTableShape, cPos, 0);
auto zIndex = shape::getOffset(outputShape, zPos, 0);
output[zIndex] = (T)colorTable[cIndex];
}
}
// Draw left line.
if (colStart >= 0) {
for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x)
for (auto c = 0; c < channels; c++) {
Nd4jLong zPos[] = {batch, i, colStart, c};
Nd4jLong cPos[] = {colorIndex, c};
auto cIndex = shape::getOffset(colorTableShape, cPos, 0);
auto zIndex = shape::getOffset(outputShape, zPos, 0);
output[zIndex] = (T)colorTable[cIndex];
}
}
// Draw right line.
if (colEnd < width) {
for (auto i = rowStartBound + threadIdx.x; i <= rowEndBound; i += blockDim.x)
for (auto c = 0; c < channels; c++) {
Nd4jLong zPos[] = {batch, i, colEnd, c};
Nd4jLong cPos[] = {colorIndex, c};
auto cIndex = shape::getOffset(colorTableShape, cPos, 0);
auto zIndex = shape::getOffset(outputShape, zPos, 0);
output[zIndex] = (T)colorTable[cIndex];
}
} }
} }
} }
@ -70,15 +143,19 @@ namespace helpers {
auto width = images->sizeAt(2); auto width = images->sizeAt(2);
auto channels = images->sizeAt(3); auto channels = images->sizeAt(3);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto colorSetSize = colors->sizeAt(0); auto boxSize = boxes->sizeAt(1);
NDArray colorsTable = DefaultColorTable(channels);
if ((colors != nullptr && colors->lengthOf() > 0)) {
colorsTable = *colors;
}
auto imagesBuf = images->getDataBuffer()->specialAsT<T>(); auto imagesBuf = images->getDataBuffer()->specialAsT<T>();
auto boxesBuf = boxes->getDataBuffer()->specialAsT<T>(); auto boxesBuf = boxes->getDataBuffer()->specialAsT<float>(); // boxes should be float32
auto colorsBuf = colors->getDataBuffer()->specialAsT<T>(); auto colorsTableBuf = colorsTable.getDataBuffer()->specialAsT<float>(); // color table is float32
auto outputBuf = output->dataBuffer()->specialAsT<T>(); auto outputBuf = output->dataBuffer()->specialAsT<T>();
drawBoundingBoxesKernel<<<batchSize > 128? 128: batchSize, 256, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(), drawBoundingBoxesKernel<<<128, 128, 1024, *stream>>>(imagesBuf, images->getSpecialShapeInfo(),
boxesBuf, boxes->getSpecialShapeInfo(), colorsBuf, colors->getSpecialShapeInfo(), boxesBuf, boxes->getSpecialShapeInfo(), colorsTableBuf, colorsTable.getSpecialShapeInfo(),
outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, colorSetSize); outputBuf, output->specialShapeInfo(), batchSize, width, height, channels, boxSize, colorsTable.lengthOf());
} }
void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) { void drawBoundingBoxesFunctor(nd4j::LaunchContext * context, NDArray* images, NDArray* boxes, NDArray* colors, NDArray* output) {

View File

@ -1470,6 +1470,64 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
delete results; delete results;
} }
// Checks resize_bilinear on a rank-3 image: a {2,3,4} double input is resized with
// integer args {10, 10} and the result must equal a rank-3 {10,10,4} reference tensor.
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
// Reference output; the last dimension (4) matches the last dimension of the input.
NDArray expected = NDArrayFactory::create<double>('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2,
9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4,
11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8,
7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4,
10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16.,
13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8,
8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6,
11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2,
16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6,
18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24.,
21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,
19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24.,
21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.});
//input = 1.f;
// Fill the input via linspace(1) - presumably the sequence 1..24; TODO confirm step default.
input.linspace(1);
nd4j::ops::resize_bilinear op;
// No size tensor is passed; the target size comes from the two integer args.
auto results = op.execute({&input}, {}, {10, 10});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
//result->printIndexedBuffer("Resized to 10x10");
//expected.printIndexedBuffer("Expect for 10x10");
// result->printShapeInfo("Output shape");
// expected.printShapeInfo("Expect shape");
// The output must keep rank 3 (same shape as the reference) and match it element-wise.
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
// The result set returned by execute() is released by the test itself.
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
@ -1852,6 +1910,53 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
delete results; delete results;
} }
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
    // Rank-3 input of shape {2, 3, 4}; its contents are generated by linspace(1).
    NDArray source = NDArrayFactory::create<double>('c', {2, 3, 4});
    source.linspace(1);

    // Reference output of shape {4, 5, 4}: each text row below holds the five
    // 4-element groups that make up one slice along the first dimension.
    NDArray reference = NDArrayFactory::create<double>('c', {4, 5, 4}, {
         1,  2,  3,  4,    1,  2,  3,  4,    5,  6,  7,  8,    5,  6,  7,  8,    9, 10, 11, 12,
         1,  2,  3,  4,    1,  2,  3,  4,    5,  6,  7,  8,    5,  6,  7,  8,    9, 10, 11, 12,
        13, 14, 15, 16,   13, 14, 15, 16,   17, 18, 19, 20,   17, 18, 19, 20,   21, 22, 23, 24,
        13, 14, 15, 16,   13, 14, 15, 16,   17, 18, 19, 20,   17, 18, 19, 20,   21, 22, 23, 24
    });

    // Run resize_nearest_neighbor with integer arguments {4, 5}.
    nd4j::ops::resize_nearest_neighbor resizeOp;
    auto outcome = resizeOp.execute({&source}, {}, {4, 5});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    NDArray* resized = outcome->at(0);

    // Debug helpers, enable when the comparison below fails:
    // resized->printIndexedBuffer("Resized to 4x5");
    // reference.printIndexedBuffer("Expect for 4x5");

    // The produced array has to agree with the reference in shape and in values.
    ASSERT_TRUE(reference.isSameShape(resized));
    ASSERT_TRUE(reference.equalsTo(resized));

    delete outcome;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) {
@ -2180,6 +2285,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0); auto result = results->at(0);
result->syncToHost();
// result->printBuffer("Bounded boxes"); // result->printBuffer("Bounded boxes");
// expected.printBuffer("Bounded expec"); // expected.printBuffer("Bounded expec");
ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.isSameShapeStrict(result));
@ -2221,6 +2327,58 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) {
delete results; delete results;
} }
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
    // Input of shape {2, 5, 5, 1}, written five values per text row.
    NDArray pictures = NDArrayFactory::create<float>('c', {2, 5, 5, 1}, {
        0.7788, 0.8012, 0.7244, 0.2309, 0.7271,
        0.1804, 0.5056, 0.8925, 0.5461, 0.9234,
        0.0856, 0.7938, 0.6591, 0.5555, 0.1596,
        0.3087, 0.1548, 0.4695, 0.9939, 0.6113,
        0.6765, 0.1800, 0.6750, 0.2246, 0.0509,
        0.4601, 0.8284, 0.2354, 0.9752, 0.8361,
        0.2585, 0.4189, 0.7028, 0.7679, 0.5373,
        0.7234, 0.2690, 0.0062, 0.0327, 0.0644,
        0.8428, 0.7494, 0.0755, 0.6245, 0.3491,
        0.5793, 0.5730, 0.1822, 0.6420, 0.9143});

    // Box array of shape {2, 2, 4}, one 4-value group per text row.
    NDArray boxCoords = NDArrayFactory::create<float>('c', {2, 2, 4}, {
        0.7717, 0.9281, 0.9846, 0.4838,
        0.6433, 0.6041, 0.6501, 0.7612,
        0.7605, 0.3948, 0.9493, 0.8600,
        0.7876, 0.8945, 0.4638, 0.7157});

    // Colour array of shape {1, 2}.
    NDArray palette = NDArrayFactory::create<float>('c', {1, 2}, {0.9441, 0.5957});

    // Reference output: identical to the input except for the entries that
    // carry the value 0.9441 (two in the third row, three in the ninth row).
    NDArray reference = NDArrayFactory::create<float>('c', {2, 5, 5, 1}, {
        0.7788, 0.8012, 0.7244, 0.2309, 0.7271,
        0.1804, 0.5056, 0.8925, 0.5461, 0.9234,
        0.0856, 0.7938, 0.9441, 0.9441, 0.1596,
        0.3087, 0.1548, 0.4695, 0.9939, 0.6113,
        0.6765, 0.18, 0.675, 0.2246, 0.0509,
        0.4601, 0.8284, 0.2354, 0.9752, 0.8361,
        0.2585, 0.4189, 0.7028, 0.7679, 0.5373,
        0.7234, 0.269, 0.0062, 0.0327, 0.0644,
        0.8428, 0.9441, 0.9441, 0.9441, 0.3491,
        0.5793, 0.573, 0.1822, 0.642, 0.9143});

    nd4j::ops::draw_bounding_boxes drawOp;
    auto outcome = drawOp.execute({&pictures, &boxCoords, &palette}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto drawn = outcome->at(0);

    // Debug helpers, enable when the comparison below fails:
    // drawn->syncToHost();
    // drawn->printBuffer("Boxes3 output");
    // reference.printBuffer("Boxes3 expect");

    // The output must match the reference both in (strict) shape and in values.
    ASSERT_TRUE(reference.isSameShapeStrict(drawn));
    ASSERT_TRUE(reference.equalsTo(drawn));

    delete outcome;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
@ -2376,6 +2534,55 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
delete results; delete results;
} }
/* public void testFakeQuantAgainstTF_1() {
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5);
INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5);
INDArray out = Nd4j.createUninitialized(x.shape());
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
assertEquals(expected, out);
}*/
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
    // Input of shape {3, 5}, accompanied by length-5 min and max vectors.
    NDArray data = NDArrayFactory::create<float>('c', {3, 5}, {
        0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f,
        0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
        0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f});
    NDArray lowerBounds = NDArrayFactory::create<float>('c', {5}, {
        -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
    NDArray upperBounds = NDArrayFactory::create<float>('c', {5}, {
        0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});

    // An alternative set of reference values was left disabled in this test:
    //     0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
    //     0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
    //     0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f
    // The values actually asserted against are the ones below.
    NDArray reference = NDArrayFactory::create<float>('c', {3, 5}, {
        0.77700233, 0.596913,   0.72314, 0.23104,    0.50982356,
        0.17930824, 0.50528157, 0.86846, 0.34995764, 0.50982356,
        0.08735529, 0.596913,   0.6574,  0.34995764, 0.15974471});

    nd4j::ops::fake_quant_with_min_max_vars_per_channel quantOp;
    auto outcome = quantOp.execute({&data, &lowerBounds, &upperBounds}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto quantized = outcome->at(0);

    // Debug helpers, enable when the comparison below fails:
    // quantized->printBuffer("Quantized per channels 5");
    // reference.printBuffer("Quantized per channels E");
    // auto diff = *quantized - reference;
    // diff.printIndexedBuffer("Difference");

    // The output must match the reference both in (strict) shape and in values.
    ASSERT_TRUE(reference.isSameShapeStrict(quantized));
    ASSERT_TRUE(reference.equalsTo(quantized));

    delete outcome;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test1) { TEST_F(DeclarableOpsTests10, batchnorm_test1) {