Shugeo_release_fixes3 (#81)
* Implementation for non_max_suppression_v3 was added. Initial version * Added check for overcome threshold. * Added definition for V3 method. * java remapping for NonMaxSuppressionV3 Signed-off-by: raver119 <raver119@gmail.com> * Fixed proporly processing of an empty output and test. * Refactored op to less threshold data to float. * Implemented cuda-based helper for non_max_suppression_v3 op. * Fixed fake_quant_with_min_max_vars op. * Fixed tests with float numbers. * - assert now stops execution - sortByKey/sortByValue now have input validation Signed-off-by: raver119 <raver119@gmail.com> * missing var Signed-off-by: raver119 <raver119@gmail.com> * Fixed proper processing for zero max_size inputs. * Refactored kernel callers. * Fixed return statement for logdet op helper. * Refactored unsorted segment SqrtN op. * get back 8 tail bytes on CUDA Signed-off-by: raver119 <raver119@gmail.com> * Refactored segment prod ops and helpers for cuda and tests. * Additional test. * CudaWorkspace tests updated for 8 tail bytes Signed-off-by: raver119 <raver119@gmail.com> * special atomic test Signed-off-by: raver119 <raver119@gmail.com> * atomicMul/atomicDiv fix for 16bit values Signed-off-by: raver119 <raver119@gmail.com> * Eliminated waste prints.master
parent
abd2017a0a
commit
009007120b
|
@ -2202,10 +2202,17 @@ void sortByKey(Nd4jPointer *extraPointers,
|
|||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||
|
||||
auto xLength = shape::length(xShapeInfo);
|
||||
auto yLength = shape::length(yShapeInfo);
|
||||
auto xEWS = shape::elementWiseStride(xShapeInfo);
|
||||
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo))
|
||||
return;
|
||||
|
||||
if (xLength != yLength)
|
||||
throw std::runtime_error("sortByKey: keys and values must have the same size");
|
||||
|
||||
|
||||
// check if xLength is a power of 2, and use bitonic sort, if that's the case
|
||||
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
|
||||
|
@ -2269,10 +2276,17 @@ void sortByValue(Nd4jPointer *extraPointers,
|
|||
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
|
||||
|
||||
auto xLength = shape::length(xShapeInfo);
|
||||
auto yLength = shape::length(yShapeInfo);
|
||||
auto xEWS = shape::elementWiseStride(xShapeInfo);
|
||||
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
|
||||
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo))
|
||||
return;
|
||||
|
||||
if (xLength != yLength)
|
||||
throw std::runtime_error("sortByValue: keys and values must have the same size");
|
||||
|
||||
|
||||
// check if xLength is a power of 2, and use bitonic sort, if that's the case
|
||||
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
|
||||
|
|
|
@ -1461,12 +1461,14 @@
|
|||
|
||||
#ifdef _RELEASE
|
||||
|
||||
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); }
|
||||
// we intentionally add 8 tail bytes here to avoid problems with atomic operations
|
||||
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT) + 8); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT) + 8)); }
|
||||
#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };
|
||||
|
||||
#else
|
||||
|
||||
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { nd4j::memory::MemoryTracker::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, VARIABLE, LENGTH * sizeof(TT)); }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); }
|
||||
// we intentionally add 8 tail bytes here to avoid problems with atomic operations
|
||||
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT) + 8); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { nd4j::memory::MemoryTracker::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, VARIABLE, LENGTH * sizeof(TT)); }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT) + 8)); }
|
||||
#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { nd4j::memory::MemoryTracker::getInstance()->countOut(VARIABLE); auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };
|
||||
|
||||
#endif
|
||||
|
|
|
@ -29,8 +29,8 @@ namespace nd4j {
|
|||
OP_IMPL(Assert, 1, 1, false) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
|
||||
if (x->e<float>(0) == 0.0f) {
|
||||
nd4j_printf("Assertion failed for node [%i]\n", block.getNodeId());
|
||||
if (!x->e<bool>(0)) {
|
||||
REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", block.getNodeId());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -21,10 +21,10 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/image_suppression.h>
|
||||
|
||||
#if NOT_EXCLUDED(OP_image_non_max_suppression)
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
#if NOT_EXCLUDED(OP_image_non_max_suppression)
|
||||
CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) {
|
||||
auto boxes = INPUT_VARIABLE(0);
|
||||
auto scales = INPUT_VARIABLE(1);
|
||||
|
@ -56,11 +56,24 @@ namespace nd4j {
|
|||
if (boxes->isEmpty() || scales->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf());
|
||||
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
|
||||
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
||||
if (output->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output);
|
||||
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, "
|
||||
"but %i is given", boxes->rankOf());
|
||||
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array "
|
||||
"should be 4, but %i is given", boxes->sizeAt(1));
|
||||
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0,
|
||||
"image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
||||
REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, "image.non_max_suppressio: The overlay "
|
||||
"threashold should be in [0, 1], but "
|
||||
"%lf is given.", overlayThreshold);
|
||||
REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0,
|
||||
"image.non_max_suppression: Boxes and scores inputs should have the same data type, but %s and %s "
|
||||
"were given.", DataTypeUtils::asString(boxes->dataType()).c_str(),
|
||||
DataTypeUtils::asString(scales->dataType()).c_str());
|
||||
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold,
|
||||
scoreThreshold, output);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -77,9 +90,11 @@ namespace nd4j {
|
|||
else
|
||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||
|
||||
if (maxOutputSize > 0) {
|
||||
auto actualIndicesCount = shape::sizeAt(in, 0);
|
||||
if (block.getTArguments()->size() > 1 || block.width() > 4) {
|
||||
auto scoreThreshold = block.getTArguments()->size() > 1?T_ARG(1):INPUT_VARIABLE(4)->e<double>(0);
|
||||
auto scoreThreshold =
|
||||
block.getTArguments()->size() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e<double>(0);
|
||||
auto scales = INPUT_VARIABLE(1);
|
||||
scales->syncToHost();
|
||||
for (auto e = 0; e < scales->lengthOf(); e++) {
|
||||
|
@ -90,7 +105,7 @@ namespace nd4j {
|
|||
}
|
||||
if (actualIndicesCount < maxOutputSize)
|
||||
maxOutputSize = actualIndicesCount;
|
||||
|
||||
}
|
||||
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);
|
||||
|
||||
return SHAPELIST(outputShape);
|
||||
|
@ -100,7 +115,107 @@ namespace nd4j {
|
|||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_INDICES});
|
||||
}
|
||||
#endif
|
||||
#if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
|
||||
DECLARE_TYPES(non_max_suppression_v3) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedOutputTypes({ALL_INDICES});
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) {
|
||||
auto boxes = INPUT_VARIABLE(0);
|
||||
auto scales = INPUT_VARIABLE(1);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
int maxOutputSize; // = INT_ARG(0);
|
||||
if (block.width() > 2)
|
||||
maxOutputSize = INPUT_VARIABLE(2)->e<int>(0);
|
||||
else if (block.getIArguments()->size() == 1)
|
||||
maxOutputSize = INT_ARG(0);
|
||||
else
|
||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||
|
||||
double overlayThreshold = 0.5;
|
||||
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
|
||||
|
||||
if (block.width() > 3) {
|
||||
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
|
||||
}
|
||||
else if (block.getTArguments()->size() > 0) {
|
||||
overlayThreshold = T_ARG(0);
|
||||
}
|
||||
|
||||
if (block.width() > 4) {
|
||||
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
|
||||
}
|
||||
else if (block.getTArguments()->size() > 1) {
|
||||
scoreThreshold = T_ARG(1);
|
||||
}
|
||||
if (boxes->isEmpty() || scales->isEmpty())
|
||||
return Status::OK();
|
||||
if (output->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but "
|
||||
"%i is given", boxes->rankOf());
|
||||
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should "
|
||||
"be 4, but %i is given", boxes->sizeAt(1));
|
||||
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0,
|
||||
"image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
||||
REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0,
|
||||
"image.non_max_suppression_v3: The overlay threashold should be in [0, 1], but %lf given.",
|
||||
overlayThreshold);
|
||||
REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0,
|
||||
"image.non_max_suppression_v3: Boxes and scores inputs should have the same data type, but %s and %s "
|
||||
"were given.", DataTypeUtils::asString(boxes->dataType()).c_str(),
|
||||
DataTypeUtils::asString(scales->dataType()).c_str());
|
||||
|
||||
helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold,
|
||||
scoreThreshold, output);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(non_max_suppression_v3) {
|
||||
auto in = inputShape->at(0);
|
||||
int outRank = shape::rank(in);
|
||||
Nd4jLong *outputShape = nullptr;
|
||||
|
||||
int maxOutputSize;
|
||||
if (block.width() > 2)
|
||||
maxOutputSize = INPUT_VARIABLE(2)->e<int>(0);
|
||||
else if (block.getIArguments()->size() == 1)
|
||||
maxOutputSize = INT_ARG(0);
|
||||
else
|
||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||
auto boxes = INPUT_VARIABLE(0);
|
||||
auto scales = INPUT_VARIABLE(1);
|
||||
|
||||
double overlayThreshold = 0.5;
|
||||
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
|
||||
|
||||
if (block.width() > 3) {
|
||||
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
|
||||
}
|
||||
else if (block.getTArguments()->size() > 0) {
|
||||
overlayThreshold = T_ARG(0);
|
||||
}
|
||||
|
||||
if (block.width() > 4) {
|
||||
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
|
||||
}
|
||||
else if (block.getTArguments()->size() > 1) {
|
||||
scoreThreshold = T_ARG(1);
|
||||
}
|
||||
|
||||
auto len = maxOutputSize;
|
||||
if (len > 0)
|
||||
len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, nullptr);
|
||||
|
||||
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(len, DataType::INT32);
|
||||
|
||||
return SHAPELIST(outputShape);
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -61,9 +61,9 @@ namespace nd4j {
|
|||
}
|
||||
DECLARE_TYPES(unsorted_segment_prod) {
|
||||
getOpDescriptor()
|
||||
->setAllowedOutputTypes({ALL_FLOATS})
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS})
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
|
||||
->setAllowedInputTypes(1, {ALL_INDICES})
|
||||
->setSameMode(false);
|
||||
}
|
||||
|
||||
|
@ -88,10 +88,10 @@ namespace nd4j {
|
|||
DECLARE_TYPES(unsorted_segment_prod_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes(1, {ALL_INTS})
|
||||
->setAllowedOutputTypes(1, {ALL_INDICES})
|
||||
->setAllowedInputTypes(0, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setAllowedInputTypes(2,{ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_INDICES})
|
||||
->setAllowedInputTypes(2,{ALL_FLOATS, ALL_INTS})
|
||||
->setSameMode(false);
|
||||
}
|
||||
|
||||
|
|
|
@ -1723,7 +1723,7 @@ namespace nd4j {
|
|||
#endif
|
||||
|
||||
/**
|
||||
* image.non_max_suppression op.
|
||||
* image.non_max_suppression ops.
|
||||
* input:
|
||||
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
||||
* 1 - scales - 1D-tensor with shape (num_boxes) by float type
|
||||
|
@ -1741,6 +1741,9 @@ namespace nd4j {
|
|||
#if NOT_EXCLUDED(OP_image_non_max_suppression)
|
||||
DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
#if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
|
||||
DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* image.non_max_suppression_overlaps op.
|
||||
|
|
|
@ -77,7 +77,13 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
//const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
|
||||
// const auto clamped_shifted = clamped - nudged_min;
|
||||
// outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
|
||||
// nudged_scale_repl +
|
||||
// nudged_min;
|
||||
//
|
||||
template <typename T>
|
||||
void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
|
||||
int lowIntBound = narrowed ? 1 : 0;
|
||||
|
@ -95,7 +101,8 @@ namespace helpers {
|
|||
else if (val > nudgedMax)
|
||||
val = nudgedMax;
|
||||
// converse value with scale and shifted with nudged min
|
||||
return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin);
|
||||
val -= nudgedMin;
|
||||
return (nd4j::math::nd4j_floor<T,T>(val / scale + T(0.5f)) * scale + nudgedMin);
|
||||
};
|
||||
|
||||
input->applyLambda<T>(fakeQuantizationWithMinMax, output);
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
//
|
||||
|
||||
#include <ops/declarable/helpers/image_suppression.h>
|
||||
//#include <blas/NDArray.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <queue>
|
||||
|
@ -90,17 +90,61 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Return intersection-over-union overlap between boxes i and j
|
||||
template <typename T>
|
||||
static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
|
||||
const T zero = static_cast<T>(0.f);
|
||||
const T yminI = math::nd4j_min(boxes.t<T>(i, 0), boxes.t<T>(i, 2));
|
||||
const T xminI = math::nd4j_min(boxes.t<T>(i, 1), boxes.t<T>(i, 3));
|
||||
const T ymaxI = math::nd4j_max(boxes.t<T>(i, 0), boxes.t<T>(i, 2));
|
||||
const T xmaxI = math::nd4j_max(boxes.t<T>(i, 1), boxes.t<T>(i, 3));
|
||||
const T yminJ = math::nd4j_min(boxes.t<T>(j, 0), boxes.t<T>(j, 2));
|
||||
const T xminJ = math::nd4j_min(boxes.t<T>(j, 1), boxes.t<T>(j, 3));
|
||||
const T ymaxJ = math::nd4j_max(boxes.t<T>(j, 0), boxes.t<T>(j, 2));
|
||||
const T xmaxJ = math::nd4j_max(boxes.t<T>(j, 1), boxes.t<T>(j, 3));
|
||||
const T areaI = (ymaxI - yminI) * (xmaxI - xminI);
|
||||
const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
|
||||
if (areaI <= zero || areaJ <= zero) {
|
||||
return zero;
|
||||
}
|
||||
const T intersectionYmin = math::nd4j_max(yminI, yminJ);
|
||||
const T intersectionXmin = math::nd4j_max(xminI, xminJ);
|
||||
const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ);
|
||||
const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ);
|
||||
const T intersectionY = intersectionYmax - intersectionYmin;
|
||||
const T intersectionX = intersectionXmax - intersectionXmin;
|
||||
const T intersectionArea = math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero);
|
||||
return intersectionArea / (areaI + areaJ - intersectionArea);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
|
||||
return boxes.t<T>(i, j);
|
||||
}
|
||||
|
||||
typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j);
|
||||
|
||||
static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
|
||||
NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.);
|
||||
BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, (boxes, i, j) , FLOAT_TYPES);
|
||||
return res;
|
||||
}
|
||||
|
||||
static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
|
||||
NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.);
|
||||
BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j) , FLOAT_TYPES);
|
||||
return res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T, typename I>
|
||||
static Nd4jLong
|
||||
nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output) {
|
||||
float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc f) {
|
||||
|
||||
// const int outputSize = maxSize->e<int>(0);
|
||||
auto numBoxes = boxes->sizeAt(0);
|
||||
//std::vector<T> scoresData(numBoxes);
|
||||
T* scoresData = scores->dataBuffer()->primaryAsT<T>();
|
||||
//std::copy_n(scores->getDataBuffer()->primaryAsT<T>(), numBoxes, scoresData.begin());
|
||||
|
||||
// Data structure for a selection candidate in NMS.
|
||||
struct Candidate {
|
||||
|
@ -113,9 +157,10 @@ namespace helpers {
|
|||
return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) ||
|
||||
(bsI._score < bsJ._score);
|
||||
};
|
||||
|
||||
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidatePriorityQueue(cmp);
|
||||
for (auto i = 0; i < scores->lengthOf(); ++i) {
|
||||
if (scoresData[i] > scoreThreshold) {
|
||||
if ((float)scoresData[i] > (float)scoreThreshold) {
|
||||
candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0}));
|
||||
}
|
||||
}
|
||||
|
@ -139,17 +184,18 @@ namespace helpers {
|
|||
// following loop.
|
||||
bool shouldHardSuppress = false;
|
||||
for (int j = static_cast<int>(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) {
|
||||
similarity = boxes->t<T>(nextCandidate._boxIndex, selected[j]);
|
||||
auto similarityA = f(*boxes, nextCandidate._boxIndex, selected[j]); //boxes->t<T>(nextCandidate._boxIndex, selected[j]);
|
||||
similarity = similarityA.template t<T>(0);
|
||||
nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity);
|
||||
|
||||
// First decide whether to perform hard suppression
|
||||
if (similarity >= static_cast<T>(overlapThreshold)) {
|
||||
if ((float)similarity >= static_cast<float>(overlapThreshold)) {
|
||||
shouldHardSuppress = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// If next_candidate survives hard suppression, apply soft suppression
|
||||
if (nextCandidate._score <= scoreThreshold) break;
|
||||
if ((float)nextCandidate._score <= (float)scoreThreshold) break;
|
||||
}
|
||||
// If `nextCandidate._score` has not dropped below `scoreThreshold`
|
||||
// by this point, then we know that we went through all of the previous
|
||||
|
@ -169,7 +215,7 @@ namespace helpers {
|
|||
selected.push_back(nextCandidate._boxIndex);
|
||||
// selected_scores.push_back(nextCandidate._score);
|
||||
}
|
||||
if (nextCandidate._score > scoreThreshold) {
|
||||
if ((float)nextCandidate._score > (float)scoreThreshold) {
|
||||
// Soft suppression has occurred and current score is still greater than
|
||||
// score_threshold; add next_candidate back onto priority queue.
|
||||
candidatePriorityQueue.push(nextCandidate);
|
||||
|
@ -188,12 +234,19 @@ namespace helpers {
|
|||
Nd4jLong
|
||||
nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyOverlaps), FLOAT_TYPES, INTEGER_TYPES);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Nd4jLong
|
||||
nonMaxSuppressionV3(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyV3), FLOAT_TYPES, INTEGER_TYPES);
|
||||
return 0;
|
||||
}
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc similiratyFunc), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
void
|
||||
nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,
|
||||
|
|
|
@ -100,7 +100,7 @@ namespace helpers {
|
|||
val = nudgedMax;
|
||||
}
|
||||
output[shape::getIndexOffset(b * channels + i, outputShape)] =
|
||||
(math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin);
|
||||
(math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5f)) * scale + nudgedMin);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,6 +79,48 @@ namespace helpers {
|
|||
return intersectionValue > threshold;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __device__ T similirityV3(T* boxes, Nd4jLong* boxesShape, int previousIndex, int nextIndex) {
|
||||
Nd4jLong previous0[] = {previousIndex, 0};
|
||||
Nd4jLong previous1[] = {previousIndex, 1};
|
||||
Nd4jLong previous2[] = {previousIndex, 2};
|
||||
Nd4jLong previous3[] = {previousIndex, 3};
|
||||
Nd4jLong next0[] = {nextIndex, 0};
|
||||
Nd4jLong next1[] = {nextIndex, 1};
|
||||
Nd4jLong next2[] = {nextIndex, 2};
|
||||
Nd4jLong next3[] = {nextIndex, 3};
|
||||
|
||||
// we have rectangle with given max values. Compute vexes of rectangle first
|
||||
|
||||
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]);
|
||||
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]);
|
||||
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]);
|
||||
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]);
|
||||
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]);
|
||||
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]);
|
||||
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]);
|
||||
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]);
|
||||
|
||||
// compute areas for comparation
|
||||
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
||||
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
||||
|
||||
// of course, areas should be positive
|
||||
if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false;
|
||||
|
||||
// compute intersection of rectangles
|
||||
T minIntersectionY = nd4j::math::nd4j_max(minYPrev, minYNext);
|
||||
T minIntersectionX = nd4j::math::nd4j_max(minXPrev, minXNext);
|
||||
T maxIntersectionY = nd4j::math::nd4j_min(maxYPrev, maxYNext);
|
||||
T maxIntersectionX = nd4j::math::nd4j_min(maxXPrev, maxXNext);
|
||||
T intersectionArea =
|
||||
nd4j::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) *
|
||||
nd4j::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f));
|
||||
T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea);
|
||||
// final check
|
||||
return intersectionValue;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// shouldSelectKernel - compute status for all selected rectangles (boxes)
|
||||
//
|
||||
|
@ -200,24 +242,33 @@ namespace helpers {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T, typename I>
|
||||
static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold) {
|
||||
static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold, bool simple) {
|
||||
bool shouldHardSuppress = false;
|
||||
T& nextCandidateScore = scores[nextCandidateIndex];
|
||||
I selectedIndex = indices[nextCandidateIndex];
|
||||
I finish = startIndices[nextCandidateIndex];
|
||||
|
||||
for (int j = selectedSize; j > finish; --j) {
|
||||
T boxVal;
|
||||
if (simple) {
|
||||
Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]};
|
||||
auto xShift = shape::getOffset(shape, xPos, 0);
|
||||
nextCandidateScore *= (boxes[xShift] <= static_cast<T>(overlapThreshold)?T(1.):T(0.));//
|
||||
boxVal = boxes[xShift];
|
||||
}
|
||||
else {
|
||||
boxVal = similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]);
|
||||
}
|
||||
if (boxVal > static_cast<T>(overlapThreshold))
|
||||
nextCandidateScore = static_cast<T>(0.f);
|
||||
|
||||
// First decide whether to perform hard suppression
|
||||
if (boxes[xShift] >= overlapThreshold) {
|
||||
if (boxVal >= overlapThreshold) {
|
||||
shouldHardSuppress = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// If nextCandidate survives hard suppression, apply soft suppression
|
||||
if (nextCandidateScore <= scoreThreshold) break;
|
||||
if (nextCandidateScore <= static_cast<T>(scoreThreshold)) break;
|
||||
}
|
||||
|
||||
return shouldHardSuppress;
|
||||
|
@ -226,7 +277,7 @@ namespace helpers {
|
|||
template <typename T, typename I>
|
||||
static __global__ void
|
||||
suppressNonMaxOverlapKernel(T* boxes, Nd4jLong* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen,
|
||||
T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength) {
|
||||
T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength, bool simple) {
|
||||
|
||||
__shared__ I selectedSize;
|
||||
__shared__ I* tempOutput;
|
||||
|
@ -253,7 +304,7 @@ namespace helpers {
|
|||
}
|
||||
// check for overlaps
|
||||
bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize,
|
||||
nextCandidateIndex, overlapThreshold, scoreThreshold);//false;
|
||||
nextCandidateIndex, overlapThreshold, scoreThreshold, simple);//false;
|
||||
T nextCandidateScore = scoresData[nextCandidateIndex];
|
||||
|
||||
startIndices[nextCandidateIndex] = selectedSize;
|
||||
|
@ -285,7 +336,7 @@ namespace helpers {
|
|||
template <typename T, typename I>
|
||||
static Nd4jLong
|
||||
nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output) {
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output, bool simple) {
|
||||
auto stream = context->getCudaStream();
|
||||
if (output)
|
||||
NDArray::prepareSpecialUse({output}, {boxes, scores});
|
||||
|
@ -318,13 +369,13 @@ namespace helpers {
|
|||
suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(),
|
||||
boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize,
|
||||
T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT<I>(), output->specialShapeInfo(),
|
||||
selectedSizeBuf.specialAsT<I>());
|
||||
selectedSizeBuf.specialAsT<I>(), simple);
|
||||
}
|
||||
else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr.
|
||||
DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT<I>());
|
||||
suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(),
|
||||
boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize,
|
||||
T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT<I>());
|
||||
T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT<I>(), simple);
|
||||
selectedSizeBuf.syncToPrimary(context, true);
|
||||
res = *selectedSizeBuf.primaryAsT<I>();
|
||||
}
|
||||
|
@ -344,7 +395,16 @@ namespace helpers {
|
|||
|
||||
Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_,
|
||||
(context, boxes, scales, maxSize, threshold, scoreThreshold, output),
|
||||
(context, boxes, scales, maxSize, threshold, scoreThreshold, output, true),
|
||||
FLOAT_TYPES, INDEXING_TYPES);
|
||||
return boxes->sizeAt(0);
|
||||
}
|
||||
|
||||
Nd4jLong
|
||||
nonMaxSuppressionV3(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_,
|
||||
(context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, false),
|
||||
FLOAT_TYPES, INDEXING_TYPES);
|
||||
return boxes->sizeAt(0);
|
||||
}
|
||||
|
|
|
@ -825,37 +825,30 @@ namespace helpers {
|
|||
NDArray::prepareSpecialUse({output}, {input});
|
||||
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||
auto stream = context->getCudaStream();
|
||||
std::unique_ptr<NDArray> tempOutput(input->dup());
|
||||
// auto inputs = tempOutput->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1});
|
||||
// for (Nd4jLong e = 0; e < packX.numberOfTads(); e++) {
|
||||
// auto subArray = inputs->at(e);
|
||||
// cholesky(context, subArray, subArray, true);
|
||||
// }
|
||||
// delete inputs;
|
||||
cholesky(context, input, tempOutput.get(), false);
|
||||
tempOutput->syncToHost();
|
||||
tempOutput->printIndexedBuffer("Cholesky res!!!");
|
||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()); // + e * n2; // + e * n2;
|
||||
auto inputBuf = reinterpret_cast<T*>(tempOutput->specialBuffer());
|
||||
output->assign(0);
|
||||
output->syncToDevice();
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(),
|
||||
{input->rankOf() - 2,
|
||||
input->rankOf() - 1});
|
||||
logDetKernel<T> << < packX.numberOfTads(), n2, 128, *stream >> >
|
||||
(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
|
||||
// }
|
||||
NDArray tempOutput(*input);
|
||||
|
||||
cholesky(context, input, &tempOutput, false);
|
||||
|
||||
auto outputBuf = output->dataBuffer()->specialAsT<T>(); //reinterpret_cast<T*>(output->specialBuffer()); // + e * n2; // + e * n2;
|
||||
auto inputBuf = tempOutput.dataBuffer()->specialAsT<T>(); //reinterpret_cast<T*>(tempOutput->specialBuffer());
|
||||
output->nullify();
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(),
|
||||
{tempOutput.rankOf() - 2,
|
||||
tempOutput.rankOf() - 1});
|
||||
logDetKernel<T> <<< 128, 512, 256, *stream >>>(inputBuf, tempOutput.specialShapeInfo(),
|
||||
packX.numberOfTads(), packX.specialShapeInfo(),
|
||||
packX.specialOffsets(), outputBuf, output->specialShapeInfo());
|
||||
output->tickWriteDevice();
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
//delete tempOutput;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||
BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
|
||||
(nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
|
||||
// BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
|
||||
// (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,129 +35,86 @@ namespace helpers {
|
|||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths,
|
||||
Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
|
||||
__shared__ Nd4jLong xLen, zLen;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
extern __shared__ unsigned char shmem[];
|
||||
val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
//val[segment] = ;
|
||||
z[zIndex] = x[shape::getIndexOffset(start, inputShape)];
|
||||
val[segment] = z[zIndex];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
// auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
// auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
for(auto segment = blockIdx.x; segment < numOfClasses; segment += gridDim.x) {
|
||||
auto zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
auto start = starts[segment];
|
||||
auto finish = start + lengths[segment];
|
||||
if (lengths[segment] == 0) {
|
||||
continue;
|
||||
}
|
||||
for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape);
|
||||
nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]);
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
z[zIndex] = val[segment];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||
static __global__ void unsortedSegmentProdLinearKernel(T* input, Nd4jLong* inputShape, I* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong* outputShape) {
|
||||
__shared__ Nd4jLong xLen, zLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x;// / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
y = reinterpret_cast<I*>(indices);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
// if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)];
|
||||
else
|
||||
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||
// val[segment] = z[zIndex];
|
||||
// }
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape);
|
||||
if (y[yIndex] == segment && e != starts[segment]) {
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
for (auto idx = start; idx < xLen; idx += step) {
|
||||
auto xIndex = shape::getIndexOffset(idx, inputShape);
|
||||
auto yIndex = shape::getIndexOffset(idx, indicesShape);
|
||||
auto segment = indices[yIndex];
|
||||
auto zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
if (lengths[segment] == 0) {
|
||||
continue;
|
||||
}
|
||||
nd4j::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]);
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentProd kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads,
|
||||
Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf,
|
||||
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
|
||||
__shared__ Nd4jLong len, total;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
len = shape::length(inputTads);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
auto segment = indices[idx]; // / threadsPerSegment;
|
||||
auto z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
|
||||
auto start = starts[segment];
|
||||
auto finish = start + lengths[segment];
|
||||
if (lengths[segment] == 0) continue;
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads);
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads);
|
||||
if (lengths[segment] > 0)
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
|
@ -177,7 +134,7 @@ namespace helpers {
|
|||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
segmentProdLinearKernel<T,I><<<128, 256, 128, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
|
@ -187,7 +144,7 @@ namespace helpers {
|
|||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
segmentProdTadKernel<T,I><<<128, 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -214,12 +171,15 @@ namespace helpers {
|
|||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
output->assign(1);
|
||||
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
unsortedSegmentProdLinearKernel<T,I><<<128, 256, 256, *stream>>>(
|
||||
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(),
|
||||
indices->dataBuffer()->specialAsT<I>(), indices->specialShapeInfo(), begins, lengths, numOfClasses,
|
||||
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(1);
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
|
@ -228,7 +188,7 @@ namespace helpers {
|
|||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
segmentProdTadKernel<T,I><<<128, 256, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -32,77 +32,48 @@ namespace ops {
|
|||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||
static __global__ void unsortedSegmentSqrtNLinearKernel(T* input, Nd4jLong* inputShape, I* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong* outputShape) {
|
||||
__shared__ Nd4jLong xLen, zLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x;// / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
y = reinterpret_cast<I*>(indices);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
// if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||
else
|
||||
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||
// val[segment] = z[zIndex];
|
||||
// }
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape);
|
||||
if (y[yIndex] == segment && e != starts[segment]) {
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||
}
|
||||
|
||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (auto idx = start; idx < xLen; idx += step) {
|
||||
auto yIndex = shape::getIndexOffset(idx, indicesShape);
|
||||
auto segment = indices[yIndex];
|
||||
auto zIndex = shape::getIndexOffset(segment, outputShape);
|
||||
if (lengths[segment] == 0) continue;
|
||||
auto xIndex = shape::getIndexOffset(idx, inputShape);
|
||||
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&output[zIndex], input[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentSqrtN kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
static __global__ void segmentSqrtNTadKernel(T* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
|
||||
__shared__ Nd4jLong len, total;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
len = shape::length(inputTads);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads);
|
||||
z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) {
|
||||
auto segment = indices[idx];
|
||||
auto x = inputBuf + inputTadOffsets[idx];
|
||||
auto z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
|
||||
auto start = starts[segment];
|
||||
auto finish = start + lengths[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads);
|
||||
|
@ -110,7 +81,6 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentSqrtNFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
|
@ -122,17 +92,21 @@ namespace helpers {
|
|||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
dim3 dims(128, 256, 256);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
output->nullify();
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(
|
||||
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(),
|
||||
indices->dataBuffer()->specialAsT<I>(), indices->specialShapeInfo(), begins, lengths, numOfClasses,
|
||||
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(0);
|
||||
output->nullify();
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
|
@ -141,7 +115,9 @@ namespace helpers {
|
|||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(
|
||||
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(), inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT<I>(),
|
||||
begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
|
|
@ -28,6 +28,8 @@ namespace helpers {
|
|||
|
||||
void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output);
|
||||
Nd4jLong nonMaxSuppressionV3(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output);
|
||||
Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
|
||||
double overlapThreshold, double scoreThreshold, NDArray* output);
|
||||
|
||||
|
|
|
@ -1446,35 +1446,63 @@ inline __device__ unsigned char nd4j_atomicMul<unsigned char>(unsigned char* add
|
|||
return (uint8_t)old;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) {
|
||||
template <typename T>
|
||||
static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) {
|
||||
size_t shift = ((size_t)address & 2);
|
||||
int *base_address = (int *)((char*)address - shift);
|
||||
int old = val, assumed;
|
||||
//printf("%u %x", *base_address);
|
||||
|
||||
union I16PAIR {
|
||||
struct {
|
||||
T H;
|
||||
T L;
|
||||
} B;
|
||||
int W;
|
||||
|
||||
__host__ __device__
|
||||
I16PAIR() {};
|
||||
|
||||
__host__ __device__
|
||||
~I16PAIR() {};
|
||||
};
|
||||
|
||||
I16PAIR pairNew, pairOld, pairAssumed;
|
||||
|
||||
pairOld.W = (int) val;
|
||||
if (reinterpret_cast<int*>(address) == base_address) {
|
||||
do {
|
||||
|
||||
assumed = old;
|
||||
old = atomicCAS(base_address, assumed, (old * val));
|
||||
} while (assumed != old);
|
||||
pairNew.B.L = pairOld.B.L;
|
||||
pairNew.B.H = pairOld.B.H * val;
|
||||
pairAssumed.W = pairOld.W;
|
||||
|
||||
return (int16_t)old;
|
||||
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
|
||||
} while (pairAssumed.W != pairOld.W);
|
||||
|
||||
return (T) pairOld.B.H;
|
||||
} else {
|
||||
do {
|
||||
|
||||
pairNew.B.H = pairOld.B.H;
|
||||
pairNew.B.L = pairOld.B.L * val;
|
||||
pairAssumed.W = pairOld.W;
|
||||
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
|
||||
|
||||
} while (pairAssumed.W != pairOld.W);
|
||||
|
||||
return (T) pairOld.B.L;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) {
|
||||
return internal_16bit_atomicMul<int16_t>(address, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ uint16_t nd4j_atomicMul<uint16_t>(uint16_t* address, uint16_t val) {
|
||||
size_t shift = ((size_t)address & 2);
|
||||
unsigned int *base_address = (unsigned int *)((char*)address - shift);
|
||||
unsigned int old = val, assumed;
|
||||
//printf("%u %x", *base_address);
|
||||
do {
|
||||
|
||||
assumed = old;
|
||||
old = atomicCAS(base_address, assumed, (old * val));
|
||||
} while (assumed != old);
|
||||
|
||||
return (uint16_t)old;
|
||||
|
||||
return internal_16bit_atomicMul<uint16_t>(address, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -1547,105 +1575,27 @@ inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong
|
|||
|
||||
template <>
|
||||
inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||
auto address_as_ull = (int*) address;
|
||||
|
||||
long addr = (long)(address);
|
||||
bool misaligned = addr & 0x3;
|
||||
|
||||
if (misaligned)
|
||||
address_as_ull = (int *) (address - 1);
|
||||
|
||||
BPAIR old, assumed, fresh;
|
||||
|
||||
old.W = *address_as_ull;
|
||||
do {
|
||||
|
||||
if (!misaligned) {
|
||||
bfloat16 res = old.B.H * val;
|
||||
fresh.B.H = res;
|
||||
fresh.B.L = old.B.L;
|
||||
} else {
|
||||
bfloat16 res = old.B.L * val;
|
||||
fresh.B.L = res;
|
||||
fresh.B.H = old.B.H;
|
||||
}
|
||||
|
||||
assumed.W = old.W;
|
||||
old.W = atomicCAS(address_as_ull, assumed.W, fresh.W);
|
||||
} while (assumed.W != old.W);
|
||||
|
||||
if (!misaligned) return old.B.H;
|
||||
else return old.B.L;
|
||||
return internal_16bit_atomicMul<bfloat16>(address, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) {
|
||||
auto address_as_ull = (int*) address;
|
||||
|
||||
long addr = (long)(address);
|
||||
bool misaligned = addr & 0x3;
|
||||
|
||||
if (misaligned)
|
||||
address_as_ull = (int *) (address - 1);
|
||||
|
||||
BPAIR old, assumed, fresh;
|
||||
|
||||
old.W = *address_as_ull;
|
||||
do {
|
||||
|
||||
if (!misaligned) {
|
||||
bfloat16 res = old.B.H * val;
|
||||
fresh.B.H = res;
|
||||
fresh.B.L = old.B.L;
|
||||
} else {
|
||||
bfloat16 res = old.B.L * val;
|
||||
fresh.B.L = res;
|
||||
fresh.B.H = old.B.H;
|
||||
}
|
||||
|
||||
assumed.W = old.W;
|
||||
old.W = atomicCAS(address_as_ull, assumed.W, fresh.W);
|
||||
} while (assumed.W != old.W);
|
||||
|
||||
if (!misaligned) return old.B.H;
|
||||
else return old.B.L;
|
||||
return internal_16bit_atomicMul<float16>(address, val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val ));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
return nd4j_atomicMul<float>(address, (float) 1.f / val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
||||
__float_as_int(assumed)));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
return nd4j_atomicMul<float16>(address, (float16) 1.f / val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
||||
__float_as_int(assumed)));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
return nd4j_atomicMul<bfloat16>(address, (bfloat16) 1.f / val);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,86 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "testlayers.h"
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <NDArray.h>
|
||||
#include <ops/ops.h>
|
||||
#include <GradCheck.h>
|
||||
#include <helpers/RandomLauncher.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
|
||||
class AtomicTests : public testing::Test {
|
||||
public:
|
||||
AtomicTests() {
|
||||
//
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) {
|
||||
auto buffer = reinterpret_cast<T*>(vbuffer);
|
||||
auto result = reinterpret_cast<T*>(vresult);
|
||||
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
|
||||
auto rem = e % 4;
|
||||
auto i = (e - rem) / 4;
|
||||
|
||||
nd4j::math::atomics::nd4j_atomicMul<T>(&result[i], buffer[e]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) {
|
||||
multiplyKernel<T><<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
|
||||
auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream());
|
||||
if (err != 0)
|
||||
nd4j::cuda_exception::build("multiply failed", err);
|
||||
}
|
||||
|
||||
static void multiplyHost(NDArray &input, NDArray &output) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
|
||||
}
|
||||
|
||||
TEST_F(AtomicTests, test_multiply) {
|
||||
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16};
|
||||
|
||||
for (auto t:dtypes) {
|
||||
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
|
||||
NDArray input('c', {4, 25}, t);
|
||||
NDArray output('c', {input.lengthOf() / 4}, t);
|
||||
NDArray exp = output.ulike();
|
||||
|
||||
input.assign(2);
|
||||
output.assign(2);
|
||||
exp.assign(32);
|
||||
|
||||
multiplyHost(input, output);
|
||||
ASSERT_EQ(exp, output);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -2330,6 +2330,76 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) {
|
||||
|
||||
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||
0.8223f, 0.2246f, 0.0049f, 0.6465f});
|
||||
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
|
||||
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1,2});
|
||||
NDArray maxSize = NDArrayFactory::create(2);
|
||||
NDArray threshold = NDArrayFactory::create(0.5f);
|
||||
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
|
||||
nd4j::ops::non_max_suppression_v3 op;
|
||||
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
// result->printBuffer("NonMaxSuppression OUtput6");
|
||||
// result->printShapeInfo("Ouput6 shape is");
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) {
|
||||
|
||||
NDArray boxes = NDArrayFactory::create<bfloat16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||
0.8223f, 0.2246f, 0.0049f, 0.6465f});
|
||||
NDArray scales = NDArrayFactory::create<bfloat16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
|
||||
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1,2});
|
||||
NDArray maxSize = NDArrayFactory::create(2);
|
||||
NDArray threshold = NDArrayFactory::create(0.5f);
|
||||
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
|
||||
nd4j::ops::non_max_suppression_v3 op;
|
||||
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
// result->printBuffer("NonMaxSuppression OUtput06");
|
||||
// result->printShapeInfo("Ouput06 shape is");
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) {
|
||||
|
||||
NDArray boxes = NDArrayFactory::create<float>('c', {3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2329f,
|
||||
0.7271f, 0.1804f, 0.5056f, 0.8929f,
|
||||
0.5461f, 0.9234f, 0.0856f, 0.7938f});
|
||||
NDArray scales = NDArrayFactory::create<float>('c', {3}, {0.7717f, 0.9281f, 0.9846f}); //3, 0, 1, 2, 4, 5
|
||||
NDArray maxSize = NDArrayFactory::create(0);
|
||||
NDArray threshold = NDArrayFactory::create(0.5f);
|
||||
NDArray scoreThreshold = NDArrayFactory::create(0.5f);
|
||||
nd4j::ops::non_max_suppression_v3 op;
|
||||
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
// result->printBuffer("NonMaxSuppression OUtput7");
|
||||
// result->printShapeInfo("Ouput6 shape is");
|
||||
ASSERT_TRUE(result->isEmpty());
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
|
||||
|
||||
|
@ -2720,22 +2790,22 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
|
|||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3});
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{
|
||||
1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824,
|
||||
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647,
|
||||
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962,
|
||||
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785,
|
||||
37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961,
|
||||
45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434,
|
||||
45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747,
|
||||
45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70.,
|
||||
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||
45., 50., 70., 45., 50., 70., 45., 50., 70.,
|
||||
45., 50., 70.});
|
||||
NDArray min = NDArrayFactory::create<float>({20., 20., 20.});
|
||||
NDArray max = NDArrayFactory::create<float>({65., 70., 90.});
|
||||
1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, 7.0588236f, 8.039216f, 9.058824f,
|
||||
10.058824f, 10.980392f, 12.078432f, 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f,
|
||||
19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, 25.058825f, 26.078432f, 26.901962f,
|
||||
28.058825f, 29.019608f, 29.92157f, 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f,
|
||||
37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, 43.058826f, 43.92157f, 45.01961f,
|
||||
45.f, 47.058823f, 48.03922f, 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f,
|
||||
45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, 45.f, 50.f, 62.862747f,
|
||||
45.f, 50.f, 65.882355f, 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
|
||||
45.f, 50.f, 70.f});
|
||||
NDArray min = NDArrayFactory::create<float>({20.f, 20.f, 20.f});
|
||||
NDArray max = NDArrayFactory::create<float>({65.f, 70.f, 90.f});
|
||||
x.linspace(1.);
|
||||
nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
|
||||
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||
|
@ -2756,36 +2826,36 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
|
|||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
|
||||
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4});
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-19.92157 , -18.980392 , -18.039217 , -16.941177 ,
|
||||
-16. , -15.058824 , -13.960785 , -13.0196085 ,
|
||||
-11.92157 , -10.980392 , -10.039217 , -8.941177 ,
|
||||
-8.000001 , -7.0588236 , -5.960785 , -5.0196085 ,
|
||||
-3.9215698 , -2.9803925 , -2.039217 , -0.94117737,
|
||||
0. , 0.94117737, 2.039215 , 2.9803925 ,
|
||||
4.07843 , 5.0196075 , 5.960783 , 7.0588226 ,
|
||||
8. , 8.941177 , 10.039215 , 10.980392 ,
|
||||
12.07843 , 13.019608 , 13.960783 , 15.058823 ,
|
||||
16. , 16.941177 , 18.039217 , 18.980392 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823 ,
|
||||
20.07843 , 21.019608 , 21.960783 , 23.058823
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-19.92157f, -18.980392f, -18.039217f, -16.941177f,
|
||||
-16.f, -15.058824f, -13.960785f, -13.0196085f,
|
||||
-11.92157f, -10.980392f, -10.039217f, -8.941177f,
|
||||
-8.000001f, -7.0588236f, -5.960785f, -5.0196085f,
|
||||
-3.9215698f, -2.9803925f, -2.039217f, -0.94117737f,
|
||||
0.f, 0.94117737f, 2.039215f, 2.9803925f,
|
||||
4.07843f, 5.0196075f, 5.960783f, 7.0588226f,
|
||||
8.f, 8.941177f, 10.039215f, 10.980392f,
|
||||
12.07843f, 13.019608f, 13.960783f, 15.058823f,
|
||||
16.f, 16.941177f, 18.039217f, 18.980392f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f,
|
||||
20.07843f, 21.019608f, 21.960783f, 23.058823f
|
||||
});
|
||||
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17});
|
||||
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23});
|
||||
|
@ -2833,9 +2903,10 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
|||
// 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f
|
||||
// });
|
||||
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {0.77700233, 0.596913, 0.72314, 0.23104, 0.50982356,
|
||||
0.17930824, 0.50528157, 0.86846, 0.34995764, 0.50982356,
|
||||
0.08735529, 0.596913, 0.6574, 0.34995764, 0.15974471});
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
|
||||
0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f,
|
||||
0.17930824f, 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f,
|
||||
0.08735529f, 0.596913f, 0.6574f, 0.34995764f, 0.15974471f});
|
||||
NDArray min = NDArrayFactory::create<float>('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
|
||||
NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
|
||||
// x.linspace(-60.);
|
||||
|
@ -2856,45 +2927,74 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
//TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
|
||||
//
|
||||
// NDArray x = NDArrayFactory::create<double>('c', {100});
|
||||
// NDArray exp = NDArrayFactory::create<double>('c', {100}, {
|
||||
// 0.f, 0.f, 0.f , 0.f , 0.06666667f, 0.06666667f ,
|
||||
// 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.13333334 ,
|
||||
// 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.20000002 ,
|
||||
// 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002 ,
|
||||
// 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668 ,
|
||||
// 0.26666668, 0.33333334, 0.33333334, 0.33333334, 0.33333334, 0.33333334 ,
|
||||
// 0.33333334, 0.40000004, 0.40000004, 0.40000004, 0.40000004, 0.40000004 ,
|
||||
// 0.40000004, 0.40000004, 0.4666667 , 0.4666667 , 0.4666667 , 0.4666667 ,
|
||||
// 0.4666667 , 0.4666667 , 0.4666667 , 0.53333336, 0.53333336, 0.53333336 ,
|
||||
// 0.53333336, 0.53333336, 0.53333336, 0.6 , 0.6 , 0.6 ,
|
||||
// 0.6 , 0.6 , 0.6 , 0.6 , 0.6666667 , 0.6666667 ,
|
||||
// 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.73333335 ,
|
||||
// 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.8000001 ,
|
||||
// 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 ,
|
||||
// 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673 ,
|
||||
// 0.86666673, 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 ,
|
||||
// 0.9333334 , 1., 1., 1.,
|
||||
// });
|
||||
// NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
|
||||
// NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
// x.linspace(0., 0.01);
|
||||
// nd4j::ops::fake_quant_with_min_max_vars op;
|
||||
// auto results = op.execute({&x, &min, &max}, {}, {});
|
||||
//
|
||||
// ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
//
|
||||
// auto result = results->at(0);
|
||||
//////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {100});
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {100}, {
|
||||
0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f,
|
||||
0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f,
|
||||
0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f,
|
||||
0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f,
|
||||
0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f,
|
||||
0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f,
|
||||
0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f,
|
||||
0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f,
|
||||
0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f,
|
||||
0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f,
|
||||
0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f,
|
||||
0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f,
|
||||
0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f,
|
||||
0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f,
|
||||
0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f,
|
||||
0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f,
|
||||
0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f,
|
||||
0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f,
|
||||
0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f,
|
||||
0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f
|
||||
});
|
||||
NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
|
||||
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
x.linspace(0., 0.01);
|
||||
nd4j::ops::fake_quant_with_min_max_vars op;
|
||||
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printBuffer("Quantized7");
|
||||
// exp.printBuffer("Expected 7");
|
||||
// ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||
// ASSERT_TRUE(exp.equalsTo(result));
|
||||
//
|
||||
// delete results;
|
||||
//}
|
||||
ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {10});
|
||||
NDArray exp = NDArrayFactory::create<float>('c', {10}, {
|
||||
0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f,
|
||||
0.6f, 0.69803923f, 0.8000001f, 0.8980393f
|
||||
});
|
||||
NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
|
||||
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
x.linspace(0., 0.1);
|
||||
nd4j::ops::fake_quant_with_min_max_vars op;
|
||||
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// x.printBuffer("SourInput8");
|
||||
// result->printBuffer("Quantized8");
|
||||
// exp.printBuffer("Expected 8");
|
||||
ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
||||
|
|
|
@ -1523,14 +1523,12 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
|
|||
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8});
|
||||
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
|
||||
|
||||
//x.printIndexedBuffer("Input");
|
||||
nd4j::ops::logdet op;
|
||||
auto result = op.execute({&x}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
|
|
@ -2231,6 +2231,27 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_05) {
|
|||
|
||||
auto result = op.execute({&x, &idx}, {}, {});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
auto res = result->at(0);
|
||||
// res->printIndexedBuffer("Segment prod 05");
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
||||
delete result;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) {
|
||||
auto x = NDArrayFactory::create<int>({1,2,3,4,5,6,7,8 });
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
auto idx = NDArrayFactory::create<int>({0,0,1,2,2,2,3,3});
|
||||
auto exp = NDArrayFactory::create<int>({ 2, 3, 120, 56});
|
||||
|
||||
nd4j::ops::segment_prod op;
|
||||
|
||||
auto result = op.execute({&x, &idx}, {}, {});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
auto res = result->at(0);
|
||||
// res->printIndexedBuffer("Segment prod 05_1");
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
||||
delete result;
|
||||
|
@ -2270,6 +2291,23 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_07) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestSegmentProd_08) {
|
||||
auto x = NDArrayFactory::create<int>({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' });
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
auto idx = NDArrayFactory::create<int>({0,0,2,2,2,2,3,3,3,3});
|
||||
auto exp = NDArrayFactory::create<int>({ 2, 1,360, 5040});
|
||||
nd4j::ops::segment_prod op;
|
||||
|
||||
auto result = op.execute({&x, &idx}, {}, {});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) {
|
||||
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
|
||||
|
@ -2341,6 +2379,22 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) {
|
||||
auto x = NDArrayFactory::create<int>({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' });
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
auto idx = NDArrayFactory::create<int>({0,0,2,2,2,2,3,3,3,3});
|
||||
auto exp = NDArrayFactory::create<int>({ 2, 1,360, 5040});
|
||||
nd4j::ops::unsorted_segment_prod op;
|
||||
|
||||
auto result = op.execute({&x, &idx}, {}, {4});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) {
|
||||
|
@ -2401,6 +2455,41 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) {
|
||||
auto x = NDArrayFactory::create<double>('c', {8, 15});
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
|
||||
auto idx = NDArrayFactory::create<int>({3, 1, 2, 1, 2, 3, 2, 1});
|
||||
auto exp = NDArrayFactory::create<double>('c', {4, 15}, {
|
||||
1., 1., 1., 1., 1.,
|
||||
1., 1., 1., 1., 1.,
|
||||
1., 1., 1., 1., 1.,
|
||||
78016., 85493., 93312., 101479., 110000.,
|
||||
118881., 128128., 137747., 147744., 158125.,
|
||||
168896., 180063., 191632., 203609., 216000.,
|
||||
172081., 182528., 193347., 204544., 216125.,
|
||||
228096., 240463., 253232., 266409., 280000.,
|
||||
294011., 308448., 323317., 338624., 354375.,
|
||||
76., 154., 234., 316., 400.,
|
||||
486., 574., 664., 756., 850.,
|
||||
946., 1044., 1144., 1246., 1350.});
|
||||
x.linspace(1.);
|
||||
|
||||
nd4j::ops::unsorted_segment_prod op;
|
||||
|
||||
auto result = op.execute({&x, &idx}, {}, {4});
|
||||
ASSERT_EQ(result->status(), Status::OK());
|
||||
//result->at(0)->printIndexedBuffer("Output");
|
||||
// result->at(0)->printShapeInfo("Out Shape");
|
||||
//exp.printIndexedBuffer("Expect");
|
||||
// exp.printShapeInfo("Exp Shape");
|
||||
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) {
|
||||
auto x = NDArrayFactory::create<double>('c', {8}, {
|
||||
|
|
|
@ -39,7 +39,7 @@ TEST_F(CudaWorkspaceTests, Basic_Tests_1) {
|
|||
ctx.setWorkspace(&workspace);
|
||||
auto array = NDArrayFactory::create<float>('c', {5, 5}, &ctx);
|
||||
|
||||
ASSERT_EQ(100, workspace.getCurrentOffset());
|
||||
ASSERT_EQ(108, workspace.getCurrentOffset());
|
||||
ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
|
||||
|
||||
array.e<int>(0);
|
||||
|
@ -55,6 +55,6 @@ TEST_F(CudaWorkspaceTests, Basic_Tests_2) {
|
|||
ctx.setWorkspace(&workspace);
|
||||
auto array = NDArrayFactory::create<float>('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx);
|
||||
|
||||
ASSERT_EQ(100, workspace.getCurrentOffset());
|
||||
ASSERT_EQ(108, workspace.getCurrentOffset());
|
||||
ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
|
||||
}
|
|
@ -86,6 +86,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.image.CropAndResize.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
|
||||
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
|
||||
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
|
||||
|
|
|
@ -53,7 +53,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2", "NonMaxSuppressionV3"};
|
||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.image;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Non max suppression
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class NonMaxSuppressionV3 extends DynamicCustomOp {
|
||||
|
||||
public NonMaxSuppressionV3() {}
|
||||
|
||||
public NonMaxSuppressionV3(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
|
||||
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) {
|
||||
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "NonMaxSuppressionV3";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"NonMaxSuppressionV3","NonMaxSuppressionV4"};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "non_max_suppression_v3";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Op.Type opType() {
|
||||
return Op.Type.CUSTOM;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
//Always 1D integer tensor (indices)
|
||||
return Collections.singletonList(DataType.INT);
|
||||
}
|
||||
}
|
|
@ -20827,7 +20827,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// #endif
|
||||
|
||||
/**
|
||||
* image.non_max_suppression op.
|
||||
* image.non_max_suppression ops.
|
||||
* input:
|
||||
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
|
||||
* 1 - scales - 1D-tensor with shape (num_boxes) by float type
|
||||
|
@ -20859,6 +20859,23 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
|
||||
@Namespace("nd4j::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public non_max_suppression_v3(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public non_max_suppression_v3 position(long position) {
|
||||
return (non_max_suppression_v3)super.position(position);
|
||||
}
|
||||
|
||||
public non_max_suppression_v3() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/*
|
||||
* image.non_max_suppression_overlaps op.
|
||||
|
|
Loading…
Reference in New Issue