Shugeo_release_fixes3 (#81)

* Implementation for non_max_suppression_v3 was added. Initial version

* Added check for overcome threshold.

* Added definition for V3 method.

* java remapping for NonMaxSuppressionV3

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed proporly processing of an empty output and test.

* Refactored op to less threshold data to float.

* Implemented cuda-based helper for non_max_suppression_v3 op.

* Fixed fake_quant_with_min_max_vars op.

* Fixed tests with float numbers.

* - assert now stops execution
- sortByKey/sortByValue now have input validation

Signed-off-by: raver119 <raver119@gmail.com>

* missing var

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed proper processing for zero max_size inputs.

* Refactored kernel callers.

* Fixed return statement for logdet op helper.

* Refactored unsorted segment SqrtN op.

* get back 8 tail bytes on CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored segment prod ops and helpers for cuda and tests.

* Additional test.

* CudaWorkspace tests updated for 8 tail bytes

Signed-off-by: raver119 <raver119@gmail.com>

* special atomic test

Signed-off-by: raver119 <raver119@gmail.com>

* atomicMul/atomicDiv fix for 16bit values

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminated waste prints.
master
shugeo 2019-11-28 20:08:51 +02:00 committed by raver119
parent abd2017a0a
commit 009007120b
24 changed files with 945 additions and 440 deletions

View File

@ -2202,10 +2202,17 @@ void sortByKey(Nd4jPointer *extraPointers,
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo); auto xLength = shape::length(xShapeInfo);
auto yLength = shape::length(yShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo); auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(xShapeInfo); auto xType = nd4j::ArrayOptions::dataType(xShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(yShapeInfo); auto yType = nd4j::ArrayOptions::dataType(yShapeInfo);
if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo))
return;
if (xLength != yLength)
throw std::runtime_error("sortByKey: keys and values must have the same size");
// check if xLength is a power of 2, and use bitonic sort, if that's the case // check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {
@ -2269,10 +2276,17 @@ void sortByValue(Nd4jPointer *extraPointers,
auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]); auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);
auto xLength = shape::length(xShapeInfo); auto xLength = shape::length(xShapeInfo);
auto yLength = shape::length(yShapeInfo);
auto xEWS = shape::elementWiseStride(xShapeInfo); auto xEWS = shape::elementWiseStride(xShapeInfo);
auto xType = nd4j::ArrayOptions::dataType(yShapeInfo); auto xType = nd4j::ArrayOptions::dataType(yShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(xShapeInfo); auto yType = nd4j::ArrayOptions::dataType(xShapeInfo);
if (shape::isEmpty(xShapeInfo) || shape::isEmpty(yShapeInfo))
return;
if (xLength != yLength)
throw std::runtime_error("sortByValue: keys and values must have the same size");
// check if xLength is a power of 2, and use bitonic sort, if that's the case // check if xLength is a power of 2, and use bitonic sort, if that's the case
if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) { if ((xLength != 0) && ((xLength & (xLength - 1)) == 0) && (xLength <= 1024 * 1024 * 10)) {

View File

@ -1461,12 +1461,14 @@
#ifdef _RELEASE #ifdef _RELEASE
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); } // we intentionally add 8 tail bytes here to avoid problems with atomic operations
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT) + 8); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT) + 8)); }
#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; }; #define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };
#else #else
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { nd4j::memory::MemoryTracker::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, VARIABLE, LENGTH * sizeof(TT)); }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); } // we intentionally add 8 tail bytes here to avoid problems with atomic operations
#define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT) + 8); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { nd4j::memory::MemoryTracker::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, VARIABLE, LENGTH * sizeof(TT)); }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT) + 8)); }
#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { nd4j::memory::MemoryTracker::getInstance()->countOut(VARIABLE); auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; }; #define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { nd4j::memory::MemoryTracker::getInstance()->countOut(VARIABLE); auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };
#endif #endif

View File

@ -29,8 +29,8 @@ namespace nd4j {
OP_IMPL(Assert, 1, 1, false) { OP_IMPL(Assert, 1, 1, false) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
if (x->e<float>(0) == 0.0f) { if (!x->e<bool>(0)) {
nd4j_printf("Assertion failed for node [%i]\n", block.getNodeId()); REQUIRE_TRUE(false, 0, "Assertion failed for node [%i]\n", block.getNodeId());
} }
return Status::OK(); return Status::OK();

View File

@ -21,10 +21,10 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/image_suppression.h> #include <ops/declarable/helpers/image_suppression.h>
#if NOT_EXCLUDED(OP_image_non_max_suppression)
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
#if NOT_EXCLUDED(OP_image_non_max_suppression)
CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(non_max_suppression, 2, 1, false, 0, 0) {
auto boxes = INPUT_VARIABLE(0); auto boxes = INPUT_VARIABLE(0);
auto scales = INPUT_VARIABLE(1); auto scales = INPUT_VARIABLE(1);
@ -56,11 +56,24 @@ namespace nd4j {
if (boxes->isEmpty() || scales->isEmpty()) if (boxes->isEmpty() || scales->isEmpty())
return Status::OK(); return Status::OK();
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); if (output->isEmpty())
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); return Status::OK();
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output); REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, "
"but %i is given", boxes->rankOf());
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array "
"should be 4, but %i is given", boxes->sizeAt(1));
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0,
"image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0, "image.non_max_suppressio: The overlay "
"threashold should be in [0, 1], but "
"%lf is given.", overlayThreshold);
REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0,
"image.non_max_suppression: Boxes and scores inputs should have the same data type, but %s and %s "
"were given.", DataTypeUtils::asString(boxes->dataType()).c_str(),
DataTypeUtils::asString(scales->dataType()).c_str());
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold,
scoreThreshold, output);
return Status::OK(); return Status::OK();
} }
@ -77,20 +90,22 @@ namespace nd4j {
else else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
auto actualIndicesCount = shape::sizeAt(in, 0); if (maxOutputSize > 0) {
if (block.getTArguments()->size() > 1 || block.width() > 4) { auto actualIndicesCount = shape::sizeAt(in, 0);
auto scoreThreshold = block.getTArguments()->size() > 1?T_ARG(1):INPUT_VARIABLE(4)->e<double>(0); if (block.getTArguments()->size() > 1 || block.width() > 4) {
auto scales = INPUT_VARIABLE(1); auto scoreThreshold =
scales->syncToHost(); block.getTArguments()->size() > 1 ? T_ARG(1) : INPUT_VARIABLE(4)->e<double>(0);
for (auto e = 0; e < scales->lengthOf(); e++) { auto scales = INPUT_VARIABLE(1);
if (scales->e<float>(e) < (float)scoreThreshold) { scales->syncToHost();
actualIndicesCount--; for (auto e = 0; e < scales->lengthOf(); e++) {
if (scales->e<float>(e) < (float) scoreThreshold) {
actualIndicesCount--;
}
} }
} }
if (actualIndicesCount < maxOutputSize)
maxOutputSize = actualIndicesCount;
} }
if (actualIndicesCount < maxOutputSize)
maxOutputSize = actualIndicesCount;
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);
return SHAPELIST(outputShape); return SHAPELIST(outputShape);
@ -100,7 +115,107 @@ namespace nd4j {
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INDICES}); ->setAllowedOutputTypes({ALL_INDICES});
} }
#endif
#if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
DECLARE_TYPES(non_max_suppression_v3) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_INDICES});
}
CUSTOM_OP_IMPL(non_max_suppression_v3, 2, 1, false, 0, 0) {
auto boxes = INPUT_VARIABLE(0);
auto scales = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
int maxOutputSize; // = INT_ARG(0);
if (block.width() > 2)
maxOutputSize = INPUT_VARIABLE(2)->e<int>(0);
else if (block.getIArguments()->size() == 1)
maxOutputSize = INT_ARG(0);
else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.width() > 3) {
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
}
else if (block.getTArguments()->size() > 0) {
overlayThreshold = T_ARG(0);
}
if (block.width() > 4) {
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
}
else if (block.getTArguments()->size() > 1) {
scoreThreshold = T_ARG(1);
}
if (boxes->isEmpty() || scales->isEmpty())
return Status::OK();
if (output->isEmpty())
return Status::OK();
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but "
"%i is given", boxes->rankOf());
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should "
"be 4, but %i is given", boxes->sizeAt(1));
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0,
"image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
REQUIRE_TRUE(overlayThreshold >= 0. && overlayThreshold <= 1., 0,
"image.non_max_suppression_v3: The overlay threashold should be in [0, 1], but %lf given.",
overlayThreshold);
REQUIRE_TRUE(boxes->dataType() == scales->dataType(), 0,
"image.non_max_suppression_v3: Boxes and scores inputs should have the same data type, but %s and %s "
"were given.", DataTypeUtils::asString(boxes->dataType()).c_str(),
DataTypeUtils::asString(scales->dataType()).c_str());
helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold,
scoreThreshold, output);
return Status::OK();
}
DECLARE_SHAPE_FN(non_max_suppression_v3) {
auto in = inputShape->at(0);
int outRank = shape::rank(in);
Nd4jLong *outputShape = nullptr;
int maxOutputSize;
if (block.width() > 2)
maxOutputSize = INPUT_VARIABLE(2)->e<int>(0);
else if (block.getIArguments()->size() == 1)
maxOutputSize = INT_ARG(0);
else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
auto boxes = INPUT_VARIABLE(0);
auto scales = INPUT_VARIABLE(1);
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.width() > 3) {
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
}
else if (block.getTArguments()->size() > 0) {
overlayThreshold = T_ARG(0);
}
if (block.width() > 4) {
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
}
else if (block.getTArguments()->size() > 1) {
scoreThreshold = T_ARG(1);
}
auto len = maxOutputSize;
if (len > 0)
len = helpers::nonMaxSuppressionV3(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, nullptr);
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(len, DataType::INT32);
return SHAPELIST(outputShape);
}
#endif
} }
} }
#endif

View File

@ -61,9 +61,9 @@ namespace nd4j {
} }
DECLARE_TYPES(unsorted_segment_prod) { DECLARE_TYPES(unsorted_segment_prod) {
getOpDescriptor() getOpDescriptor()
->setAllowedOutputTypes({ALL_FLOATS}) ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS})
->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INDICES})
->setSameMode(false); ->setSameMode(false);
} }
@ -88,10 +88,10 @@ namespace nd4j {
DECLARE_TYPES(unsorted_segment_prod_bp) { DECLARE_TYPES(unsorted_segment_prod_bp) {
getOpDescriptor() getOpDescriptor()
->setAllowedOutputTypes(0, {ALL_FLOATS}) ->setAllowedOutputTypes(0, {ALL_FLOATS})
->setAllowedOutputTypes(1, {ALL_INTS}) ->setAllowedOutputTypes(1, {ALL_INDICES})
->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(0, {ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INDICES})
->setAllowedInputTypes(2,{ALL_FLOATS}) ->setAllowedInputTypes(2,{ALL_FLOATS, ALL_INTS})
->setSameMode(false); ->setSameMode(false);
} }

View File

@ -1723,7 +1723,7 @@ namespace nd4j {
#endif #endif
/** /**
* image.non_max_suppression op. * image.non_max_suppression ops.
* input: * input:
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
* 1 - scales - 1D-tensor with shape (num_boxes) by float type * 1 - scales - 1D-tensor with shape (num_boxes) by float type
@ -1741,6 +1741,9 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_image_non_max_suppression) #if NOT_EXCLUDED(OP_image_non_max_suppression)
DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0); DECLARE_CUSTOM_OP(non_max_suppression, 2, 1, false, 0, 0);
#endif #endif
#if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
DECLARE_CUSTOM_OP(non_max_suppression_v3, 2, 1, false, 0, 0);
#endif
/* /*
* image.non_max_suppression_overlaps op. * image.non_max_suppression_overlaps op.

View File

@ -77,7 +77,13 @@ namespace helpers {
} }
} }
} }
//
//const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min);
// const auto clamped_shifted = clamped - nudged_min;
// outputs.device(d) = (clamped_shifted / nudged_scale_repl + 0.5f).floor() *
// nudged_scale_repl +
// nudged_min;
//
template <typename T> template <typename T>
void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) { void fakeQuantWithMinMaxVars_(NDArray* input, NDArray* min, NDArray* max, int numBits, bool narrowed, NDArray* output) {
int lowIntBound = narrowed ? 1 : 0; int lowIntBound = narrowed ? 1 : 0;
@ -95,7 +101,8 @@ namespace helpers {
else if (val > nudgedMax) else if (val > nudgedMax)
val = nudgedMax; val = nudgedMax;
// converse value with scale and shifted with nudged min // converse value with scale and shifted with nudged min
return (nd4j::math::nd4j_floor<T,T>((val - nudgedMin)/scale + T(0.5)) * scale + nudgedMin); val -= nudgedMin;
return (nd4j::math::nd4j_floor<T,T>(val / scale + T(0.5f)) * scale + nudgedMin);
}; };
input->applyLambda<T>(fakeQuantizationWithMinMax, output); input->applyLambda<T>(fakeQuantizationWithMinMax, output);

View File

@ -19,7 +19,7 @@
// //
#include <ops/declarable/helpers/image_suppression.h> #include <ops/declarable/helpers/image_suppression.h>
//#include <blas/NDArray.h> #include <NDArrayFactory.h>
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <queue> #include <queue>
@ -90,17 +90,61 @@ namespace helpers {
} }
} }
} }
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Return intersection-over-union overlap between boxes i and j
template <typename T>
static inline T similirityV3_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
const T zero = static_cast<T>(0.f);
const T yminI = math::nd4j_min(boxes.t<T>(i, 0), boxes.t<T>(i, 2));
const T xminI = math::nd4j_min(boxes.t<T>(i, 1), boxes.t<T>(i, 3));
const T ymaxI = math::nd4j_max(boxes.t<T>(i, 0), boxes.t<T>(i, 2));
const T xmaxI = math::nd4j_max(boxes.t<T>(i, 1), boxes.t<T>(i, 3));
const T yminJ = math::nd4j_min(boxes.t<T>(j, 0), boxes.t<T>(j, 2));
const T xminJ = math::nd4j_min(boxes.t<T>(j, 1), boxes.t<T>(j, 3));
const T ymaxJ = math::nd4j_max(boxes.t<T>(j, 0), boxes.t<T>(j, 2));
const T xmaxJ = math::nd4j_max(boxes.t<T>(j, 1), boxes.t<T>(j, 3));
const T areaI = (ymaxI - yminI) * (xmaxI - xminI);
const T areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ);
if (areaI <= zero || areaJ <= zero) {
return zero;
}
const T intersectionYmin = math::nd4j_max(yminI, yminJ);
const T intersectionXmin = math::nd4j_max(xminI, xminJ);
const T intersectionYmax = math::nd4j_min(ymaxI, ymaxJ);
const T intersectionXmax = math::nd4j_min(xmaxI, xmaxJ);
const T intersectionY = intersectionYmax - intersectionYmin;
const T intersectionX = intersectionXmax - intersectionXmin;
const T intersectionArea = math::nd4j_max(intersectionY, zero) * math::nd4j_max(intersectionX, zero);
return intersectionArea / (areaI + areaJ - intersectionArea);
}
template <typename T>
static inline T similiratyOverlaps_(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
return boxes.t<T>(i, j);
}
typedef NDArray (*SimiliratyFunc)(NDArray const& boxes, Nd4jLong i, Nd4jLong j);
static NDArray similiratyOverlaps(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.);
BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similiratyOverlaps_, (boxes, i, j) , FLOAT_TYPES);
return res;
}
static NDArray similiratyV3(NDArray const& boxes, Nd4jLong i, Nd4jLong j) {
NDArray res(boxes.dataType(), boxes.getContext()); // = NDArrayFactory::create(0.);
BUILD_SINGLE_SELECTOR(boxes.dataType(), res = similirityV3_, (boxes, i, j) , FLOAT_TYPES);
return res;
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, typename I> template <typename T, typename I>
static Nd4jLong static Nd4jLong
nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize,
double overlapThreshold, double scoreThreshold, NDArray* output) { float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc f) {
// const int outputSize = maxSize->e<int>(0);
auto numBoxes = boxes->sizeAt(0); auto numBoxes = boxes->sizeAt(0);
//std::vector<T> scoresData(numBoxes);
T* scoresData = scores->dataBuffer()->primaryAsT<T>(); T* scoresData = scores->dataBuffer()->primaryAsT<T>();
//std::copy_n(scores->getDataBuffer()->primaryAsT<T>(), numBoxes, scoresData.begin());
// Data structure for a selection candidate in NMS. // Data structure for a selection candidate in NMS.
struct Candidate { struct Candidate {
@ -113,9 +157,10 @@ namespace helpers {
return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) || return ((bsI._score == bsJ._score) && (bsI._boxIndex > bsJ._boxIndex)) ||
(bsI._score < bsJ._score); (bsI._score < bsJ._score);
}; };
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidatePriorityQueue(cmp); std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidatePriorityQueue(cmp);
for (auto i = 0; i < scores->lengthOf(); ++i) { for (auto i = 0; i < scores->lengthOf(); ++i) {
if (scoresData[i] > scoreThreshold) { if ((float)scoresData[i] > (float)scoreThreshold) {
candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0})); candidatePriorityQueue.emplace(Candidate({i, scoresData[i], 0}));
} }
} }
@ -139,17 +184,18 @@ namespace helpers {
// following loop. // following loop.
bool shouldHardSuppress = false; bool shouldHardSuppress = false;
for (int j = static_cast<int>(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) { for (int j = static_cast<int>(selected.size()) - 1; j >= nextCandidate._suppressBeginIndex; --j) {
similarity = boxes->t<T>(nextCandidate._boxIndex, selected[j]); auto similarityA = f(*boxes, nextCandidate._boxIndex, selected[j]); //boxes->t<T>(nextCandidate._boxIndex, selected[j]);
similarity = similarityA.template t<T>(0);
nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity); nextCandidate._score *= T(similarity <= overlapThreshold?1.0:0.); //suppressWeightFunc(similarity);
// First decide whether to perform hard suppression // First decide whether to perform hard suppression
if (similarity >= static_cast<T>(overlapThreshold)) { if ((float)similarity >= static_cast<float>(overlapThreshold)) {
shouldHardSuppress = true; shouldHardSuppress = true;
break; break;
} }
// If next_candidate survives hard suppression, apply soft suppression // If next_candidate survives hard suppression, apply soft suppression
if (nextCandidate._score <= scoreThreshold) break; if ((float)nextCandidate._score <= (float)scoreThreshold) break;
} }
// If `nextCandidate._score` has not dropped below `scoreThreshold` // If `nextCandidate._score` has not dropped below `scoreThreshold`
// by this point, then we know that we went through all of the previous // by this point, then we know that we went through all of the previous
@ -169,7 +215,7 @@ namespace helpers {
selected.push_back(nextCandidate._boxIndex); selected.push_back(nextCandidate._boxIndex);
// selected_scores.push_back(nextCandidate._score); // selected_scores.push_back(nextCandidate._score);
} }
if (nextCandidate._score > scoreThreshold) { if ((float)nextCandidate._score > (float)scoreThreshold) {
// Soft suppression has occurred and current score is still greater than // Soft suppression has occurred and current score is still greater than
// score_threshold; add next_candidate back onto priority queue. // score_threshold; add next_candidate back onto priority queue.
candidatePriorityQueue.push(nextCandidate); candidatePriorityQueue.push(nextCandidate);
@ -188,12 +234,19 @@ namespace helpers {
Nd4jLong Nd4jLong
nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output) { double overlapThreshold, double scoreThreshold, NDArray* output) {
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output), FLOAT_TYPES, INTEGER_TYPES); BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyOverlaps), FLOAT_TYPES, INTEGER_TYPES);
return 0;
}
Nd4jLong
nonMaxSuppressionV3(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output) {
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output == nullptr?DataType::INT32:output->dataType(), return nonMaxSuppressionGeneric_, (context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, similiratyV3), FLOAT_TYPES, INTEGER_TYPES);
return 0; return 0;
} }
BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, BUILD_DOUBLE_TEMPLATE(template Nd4jLong nonMaxSuppressionGeneric_, (nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); float overlapThreshold, float scoreThreshold, NDArray* output, SimiliratyFunc similiratyFunc), FLOAT_TYPES, INTEGER_TYPES);
void void
nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,

View File

@ -100,7 +100,7 @@ namespace helpers {
val = nudgedMax; val = nudgedMax;
} }
output[shape::getIndexOffset(b * channels + i, outputShape)] = output[shape::getIndexOffset(b * channels + i, outputShape)] =
(math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); (math::nd4j_floor<T, T>((val - nudgedMin) / scale + T(0.5f)) * scale + nudgedMin);
}; };
} }
} }

View File

@ -79,7 +79,49 @@ namespace helpers {
return intersectionValue > threshold; return intersectionValue > threshold;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template <typename T>
static __device__ T similirityV3(T* boxes, Nd4jLong* boxesShape, int previousIndex, int nextIndex) {
Nd4jLong previous0[] = {previousIndex, 0};
Nd4jLong previous1[] = {previousIndex, 1};
Nd4jLong previous2[] = {previousIndex, 2};
Nd4jLong previous3[] = {previousIndex, 3};
Nd4jLong next0[] = {nextIndex, 0};
Nd4jLong next1[] = {nextIndex, 1};
Nd4jLong next2[] = {nextIndex, 2};
Nd4jLong next3[] = {nextIndex, 3};
// we have rectangle with given max values. Compute vexes of rectangle first
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]);
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]);
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous0)], boxes[shape::getOffset(boxesShape, previous2)]);
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, previous1)], boxes[shape::getOffset(boxesShape, previous3)]);
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]);
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]);
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, next0)], boxes[shape::getOffset(boxesShape, next2)]);
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(boxesShape, next1)], boxes[shape::getOffset(boxesShape, next3)]);
// compute areas for comparation
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
// of course, areas should be positive
if (areaNext <= T(0.f) || areaPrev <= T(0.f)) return false;
// compute intersection of rectangles
T minIntersectionY = nd4j::math::nd4j_max(minYPrev, minYNext);
T minIntersectionX = nd4j::math::nd4j_max(minXPrev, minXNext);
T maxIntersectionY = nd4j::math::nd4j_min(maxYPrev, maxYNext);
T maxIntersectionX = nd4j::math::nd4j_min(maxXPrev, maxXNext);
T intersectionArea =
nd4j::math::nd4j_max(T(maxIntersectionY - minIntersectionY), T(0.0f)) *
nd4j::math::nd4j_max(T(maxIntersectionX - minIntersectionX), T(0.0f));
T intersectionValue = intersectionArea / (areaPrev + areaNext - intersectionArea);
// final check
return intersectionValue;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// shouldSelectKernel - compute status for all selected rectangles (boxes) // shouldSelectKernel - compute status for all selected rectangles (boxes)
// //
// we compute boolean flag as shared uint32 and return it on final only for the first thread // we compute boolean flag as shared uint32 and return it on final only for the first thread
@ -139,7 +181,7 @@ namespace helpers {
} }
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation // nonMaxSuppressionV2 algorithm - given from TF NonMaxSuppressionV2 implementation
// //
template <typename T, typename I> template <typename T, typename I>
@ -200,24 +242,33 @@ namespace helpers {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, typename I> template <typename T, typename I>
static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold) { static __device__ bool checkOverlapBoxes(T* boxes, Nd4jLong* shape, T* scores, I* indices, I* selectedIndices, I* startIndices, I selectedSize, I nextCandidateIndex, T overlapThreshold, T scoreThreshold, bool simple) {
bool shouldHardSuppress = false; bool shouldHardSuppress = false;
T& nextCandidateScore = scores[nextCandidateIndex]; T& nextCandidateScore = scores[nextCandidateIndex];
I selectedIndex = indices[nextCandidateIndex]; I selectedIndex = indices[nextCandidateIndex];
I finish = startIndices[nextCandidateIndex]; I finish = startIndices[nextCandidateIndex];
for (int j = selectedSize; j > finish; --j) { for (int j = selectedSize; j > finish; --j) {
Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]}; T boxVal;
auto xShift = shape::getOffset(shape, xPos, 0); if (simple) {
nextCandidateScore *= (boxes[xShift] <= static_cast<T>(overlapThreshold)?T(1.):T(0.));// Nd4jLong xPos[] = {selectedIndex, selectedIndices[j - 1]};
auto xShift = shape::getOffset(shape, xPos, 0);
boxVal = boxes[xShift];
}
else {
boxVal = similirityV3(boxes, shape, selectedIndex, selectedIndices[j - 1]);
}
if (boxVal > static_cast<T>(overlapThreshold))
nextCandidateScore = static_cast<T>(0.f);
// First decide whether to perform hard suppression // First decide whether to perform hard suppression
if (boxes[xShift] >= overlapThreshold) { if (boxVal >= overlapThreshold) {
shouldHardSuppress = true; shouldHardSuppress = true;
break; break;
} }
// If nextCandidate survives hard suppression, apply soft suppression // If nextCandidate survives hard suppression, apply soft suppression
if (nextCandidateScore <= scoreThreshold) break; if (nextCandidateScore <= static_cast<T>(scoreThreshold)) break;
} }
return shouldHardSuppress; return shouldHardSuppress;
@ -226,7 +277,7 @@ namespace helpers {
template <typename T, typename I> template <typename T, typename I>
static __global__ void static __global__ void
suppressNonMaxOverlapKernel(T* boxes, Nd4jLong* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen, suppressNonMaxOverlapKernel(T* boxes, Nd4jLong* boxesShape, T* scoresData, I* indices, I* startIndices, Nd4jLong length, I maxOutputLen,
T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength) { T overlapThreshold, T scoreThreshold, I* output, Nd4jLong* outputShape, I* outputLength, bool simple) {
__shared__ I selectedSize; __shared__ I selectedSize;
__shared__ I* tempOutput; __shared__ I* tempOutput;
@ -253,7 +304,7 @@ namespace helpers {
} }
// check for overlaps // check for overlaps
bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize, bool shouldHardSuppress = checkOverlapBoxes(boxes, boxesShape, scoresData, indices, tempOutput, startIndices, selectedSize,
nextCandidateIndex, overlapThreshold, scoreThreshold);//false; nextCandidateIndex, overlapThreshold, scoreThreshold, simple);//false;
T nextCandidateScore = scoresData[nextCandidateIndex]; T nextCandidateScore = scoresData[nextCandidateIndex];
startIndices[nextCandidateIndex] = selectedSize; startIndices[nextCandidateIndex] = selectedSize;
@ -285,7 +336,7 @@ namespace helpers {
template <typename T, typename I> template <typename T, typename I>
static Nd4jLong static Nd4jLong
nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize, nonMaxSuppressionGeneric_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int outputSize,
double overlapThreshold, double scoreThreshold, NDArray* output) { double overlapThreshold, double scoreThreshold, NDArray* output, bool simple) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
if (output) if (output)
NDArray::prepareSpecialUse({output}, {boxes, scores}); NDArray::prepareSpecialUse({output}, {boxes, scores});
@ -315,16 +366,16 @@ namespace helpers {
Nd4jLong res = 0; Nd4jLong res = 0;
if (output) { // this part used when output shape already calculated to fill up values on output if (output) { // this part used when output shape already calculated to fill up values on output
DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT<I>()); DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT<I>());
suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(), suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(),
boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize, boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I) outputSize,
T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT<I>(), output->specialShapeInfo(), T(overlapThreshold), T(scoreThreshold), output->dataBuffer()->specialAsT<I>(), output->specialShapeInfo(),
selectedSizeBuf.specialAsT<I>()); selectedSizeBuf.specialAsT<I>(), simple);
} }
else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr. else { // this case used on calculation of output shape. Output and output shape shoulde be nullptr.
DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT<I>()); DataBuffer selectedSizeBuf(&selectedSize, sizeof(I), DataTypeUtils::fromT<I>());
suppressNonMaxOverlapKernel <<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(), suppressNonMaxOverlapKernel<<<1, 1, 1024, *stream >>> (boxes->dataBuffer()->specialAsT<T>(),
boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize, boxes->specialShapeInfo(), scoresData, indexBuf, startIndices, scores->lengthOf(), (I)outputSize,
T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT<I>()); T(overlapThreshold), T(scoreThreshold), (I*)nullptr, (Nd4jLong*) nullptr, selectedSizeBuf.specialAsT<I>(), simple);
selectedSizeBuf.syncToPrimary(context, true); selectedSizeBuf.syncToPrimary(context, true);
res = *selectedSizeBuf.primaryAsT<I>(); res = *selectedSizeBuf.primaryAsT<I>();
} }
@ -344,7 +395,16 @@ namespace helpers {
Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) { Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, double scoreThreshold, NDArray* output) {
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_, BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_,
(context, boxes, scales, maxSize, threshold, scoreThreshold, output), (context, boxes, scales, maxSize, threshold, scoreThreshold, output, true),
FLOAT_TYPES, INDEXING_TYPES);
return boxes->sizeAt(0);
}
Nd4jLong
nonMaxSuppressionV3(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output) {
BUILD_DOUBLE_SELECTOR(boxes->dataType(), output ? output->dataType():DataType::INT32, return nonMaxSuppressionGeneric_,
(context, boxes, scores, maxSize, overlapThreshold, scoreThreshold, output, false),
FLOAT_TYPES, INDEXING_TYPES); FLOAT_TYPES, INDEXING_TYPES);
return boxes->sizeAt(0); return boxes->sizeAt(0);
} }

View File

@ -825,37 +825,30 @@ namespace helpers {
NDArray::prepareSpecialUse({output}, {input}); NDArray::prepareSpecialUse({output}, {input});
auto n2 = input->sizeAt(-1) * input->sizeAt(-2); auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
std::unique_ptr<NDArray> tempOutput(input->dup()); NDArray tempOutput(*input);
// auto inputs = tempOutput->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1});
// for (Nd4jLong e = 0; e < packX.numberOfTads(); e++) { cholesky(context, input, &tempOutput, false);
// auto subArray = inputs->at(e);
// cholesky(context, subArray, subArray, true); auto outputBuf = output->dataBuffer()->specialAsT<T>(); //reinterpret_cast<T*>(output->specialBuffer()); // + e * n2; // + e * n2;
// } auto inputBuf = tempOutput.dataBuffer()->specialAsT<T>(); //reinterpret_cast<T*>(tempOutput->specialBuffer());
// delete inputs; output->nullify();
cholesky(context, input, tempOutput.get(), false); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput.getShapeInfo(),
tempOutput->syncToHost(); {tempOutput.rankOf() - 2,
tempOutput->printIndexedBuffer("Cholesky res!!!"); tempOutput.rankOf() - 1});
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()); // + e * n2; // + e * n2; logDetKernel<T> <<< 128, 512, 256, *stream >>>(inputBuf, tempOutput.specialShapeInfo(),
auto inputBuf = reinterpret_cast<T*>(tempOutput->specialBuffer()); packX.numberOfTads(), packX.specialShapeInfo(),
output->assign(0); packX.specialOffsets(), outputBuf, output->specialShapeInfo());
output->syncToDevice(); output->tickWriteDevice();
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(),
{input->rankOf() - 2,
input->rankOf() - 1});
logDetKernel<T> << < packX.numberOfTads(), n2, 128, *stream >> >
(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
// }
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
//delete tempOutput;
return Status::OK(); return Status::OK();
} }
int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) { int logdetFunctor(nd4j::LaunchContext *context, NDArray *input, NDArray *output) {
BUILD_SINGLE_SELECTOR(output->dataType(), logdetFunctor_, (context, input, output), FLOAT_NATIVE); BUILD_SINGLE_SELECTOR(output->dataType(), return logdetFunctor_, (context, input, output), FLOAT_NATIVE);
} }
BUILD_SINGLE_TEMPLATE(template int logdetFunctor_, // BUILD_SINGLE_TEMPLATE(template int logdetFunctor_,
(nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE); // (nd4j::LaunchContext * context, NDArray * input, NDArray * output), FLOAT_NATIVE);
} }
} }
} }

View File

@ -35,127 +35,84 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I> template <typename T, typename I>
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths,
__shared__ T* val; Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ Nd4jLong xLen, zLen;
__shared__ T* x; __shared__ T* x;
__shared__ T* z; __shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T*>(input); x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output); z = reinterpret_cast<T*>(output);
extern __shared__ unsigned char shmem[];
val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape); xLen = shape::length(inputShape);
zLen = shape::length(outputShape); zLen = shape::length(outputShape);
}
__syncthreads();
if (segment < numOfClasses) { for(auto segment = blockIdx.x; segment < numOfClasses; segment += gridDim.x) {
zIndex = shape::getIndexOffset(segment, outputShape); auto zIndex = shape::getIndexOffset(segment, outputShape);
start = starts[segment]; auto start = starts[segment];
finish = start + lengths[segment]; auto finish = start + lengths[segment];
//val[segment] = ; if (lengths[segment] == 0) {
z[zIndex] = x[shape::getIndexOffset(start, inputShape)]; continue;
val[segment] = z[zIndex]; }
for (auto e = start + threadIdx.x; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape);
nd4j::math::atomics::nd4j_atomicMul(&z[segment], x[xIndex]);
} }
} }
__syncthreads();
// auto tid = threadIdx.x + blockIdx.x * blockDim.x;
// auto step = blockDim.x * gridDim.x;
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape);
nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]);
}
__syncthreads();
if (threadIdx.x == 0) {
z[zIndex] = val[segment];
}
__syncthreads();
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I> template <typename T, typename I>
static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { static __global__ void unsortedSegmentProdLinearKernel(T* input, Nd4jLong* inputShape, I* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong* outputShape) {
__shared__ T* val; __shared__ Nd4jLong xLen, zLen;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ I* y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x;// / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
y = reinterpret_cast<I*>(indices);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape); xLen = shape::length(inputShape);
zLen = shape::length(outputShape); zLen = shape::length(outputShape);
// if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)];
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
// val[segment] = z[zIndex];
// }
} }
__syncthreads(); __syncthreads();
if (lengths[segment] > 0) auto start = threadIdx.x + blockIdx.x * blockDim.x;
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { auto step = blockDim.x * gridDim.x;
auto xIndex = shape::getIndexOffset(e, inputShape); for (auto idx = start; idx < xLen; idx += step) {
auto yIndex = shape::getIndexOffset(e, indicesShape); auto xIndex = shape::getIndexOffset(idx, inputShape);
if (y[yIndex] == segment && e != starts[segment]) { auto yIndex = shape::getIndexOffset(idx, indicesShape);
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); auto segment = indices[yIndex];
} auto zIndex = shape::getIndexOffset(segment, outputShape);
if (lengths[segment] == 0) {
continue;
} }
nd4j::math::atomics::nd4j_atomicMul(&output[zIndex], input[xIndex]);
}
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
// SegmentProd kernel // SegmentProd kernel
template <typename T, typename I> template <typename T, typename I>
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads,
__shared__ T* val; Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf,
__shared__ Nd4jLong len, segment, zIndex, total; Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish; __shared__ Nd4jLong len, total;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0); total = shape::sizeAt(inputShape, 0);
len = shape::length(inputTads);
} }
__syncthreads(); __syncthreads();
auto idx = blockIdx.x; for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) {
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) { auto segment = indices[idx]; // / threadsPerSegment;
for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
auto xIndex = shape::getIndexOffset(e, inputTads); auto start = starts[segment];
auto zIndex = shape::getIndexOffset(e, outputTads); auto finish = start + lengths[segment];
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); if (lengths[segment] == 0) continue;
} for (auto e = threadIdx.x; e < len; e += blockDim.x) {
} auto xIndex = shape::getIndexOffset(e, inputTads);
else { auto zIndex = shape::getIndexOffset(e, outputTads);
for (auto e = threadIdx.x; e < len; e += blockDim.x) { nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
auto xIndex = shape::getIndexOffset(e, inputTads);
auto zIndex = shape::getIndexOffset(e, outputTads);
if (lengths[segment] > 0)
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
} }
} }
} }
@ -177,7 +134,7 @@ namespace helpers {
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) { if (input->isVector()) {
segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); segmentProdLinearKernel<T,I><<<128, 256, 128, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
} }
else { else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
@ -187,7 +144,7 @@ namespace helpers {
Nd4jLong* inputTadOffsets = packX.specialOffsets(); Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo(); Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets(); Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); segmentProdTadKernel<T,I><<<128, 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
} }
} }
@ -214,12 +171,15 @@ namespace helpers {
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
output->assign(1);
if (input->isVector()) { if (input->isVector()) {
unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); unsortedSegmentProdLinearKernel<T,I><<<128, 256, 256, *stream>>>(
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(),
indices->dataBuffer()->specialAsT<I>(), indices->specialShapeInfo(), begins, lengths, numOfClasses,
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
} }
else { else {
output->assign(1);
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
@ -228,7 +188,7 @@ namespace helpers {
Nd4jLong* outputTads = packZ.specialShapeInfo(); Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets(); Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0); dims.x = input->sizeAt(0);
segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); segmentProdTadKernel<T,I><<<128, 256, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
} }
} }

View File

@ -32,82 +32,52 @@ namespace ops {
namespace helpers { namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I> template <typename T, typename I>
static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { static __global__ void unsortedSegmentSqrtNLinearKernel(T* input, Nd4jLong* inputShape, I* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, T* output, Nd4jLong* outputShape) {
__shared__ T* val; __shared__ Nd4jLong xLen, zLen;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ I* y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x;// / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
y = reinterpret_cast<I*>(indices);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape); xLen = shape::length(inputShape);
zLen = shape::length(outputShape); zLen = shape::length(outputShape);
// if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
// val[segment] = z[zIndex];
// }
} }
__syncthreads(); __syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto xIndex = shape::getIndexOffset(e, inputShape); auto step = blockDim.x * gridDim.x;
auto yIndex = shape::getIndexOffset(e, indicesShape);
if (y[yIndex] == segment && e != starts[segment]) { for (auto idx = start; idx < xLen; idx += step) {
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment])); auto yIndex = shape::getIndexOffset(idx, indicesShape);
} auto segment = indices[yIndex];
} auto zIndex = shape::getIndexOffset(segment, outputShape);
if (lengths[segment] == 0) continue;
auto xIndex = shape::getIndexOffset(idx, inputShape);
nd4j::math::atomics::nd4j_atomicAdd(&output[zIndex], input[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
}
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
// SegmentSqrtN kernel // SegmentSqrtN kernel
template <typename T, typename I> template <typename T, typename I>
static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { static __global__ void segmentSqrtNTadKernel(T* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total; __shared__ Nd4jLong len, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0); total = shape::sizeAt(inputShape, 0);
len = shape::length(inputTads);
} }
__syncthreads(); __syncthreads();
auto idx = blockIdx.x; for (auto idx = blockIdx.x; idx < total; idx += gridDim.x) {
if (blockIdx.x <= total) { auto segment = indices[idx];
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; auto x = inputBuf + inputTadOffsets[idx];
if (blockIdx.x == start) { auto z = reinterpret_cast<T *>(outputBuf) + outputTadOffsets[segment];
for (auto e = threadIdx.x; e < len; e += blockDim.x) { auto start = starts[segment];
auto xIndex = shape::getIndexOffset(e, inputTads); auto finish = start + lengths[segment];
auto zIndex = shape::getIndexOffset(e, outputTads);
z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]); for (auto e = threadIdx.x; e < len; e += blockDim.x) {
} auto xIndex = shape::getIndexOffset(e, inputTads);
} auto zIndex = shape::getIndexOffset(e, outputTads);
else { nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads);
auto zIndex = shape::getIndexOffset(e, outputTads);
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
}
} }
} }
} }
@ -122,17 +92,21 @@ namespace helpers {
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); // classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); // dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
dim3 dims(128, 256, 256);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); // int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
output->nullify();
if (input->isVector()) { if (input->isVector()) {
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(),
indices->dataBuffer()->specialAsT<I>(), indices->specialShapeInfo(), begins, lengths, numOfClasses,
output->dataBuffer()->specialAsT<T>(), output->specialShapeInfo());
} }
else { else {
output->assign(0); output->nullify();
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
@ -141,7 +115,9 @@ namespace helpers {
Nd4jLong* outputTads = packZ.specialShapeInfo(); Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets(); Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0); dims.x = input->sizeAt(0);
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(
input->dataBuffer()->specialAsT<T>(), input->specialShapeInfo(), inputTads, inputTadOffsets, indices->dataBuffer()->specialAsT<I>(),
begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
} }
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //

View File

@ -28,6 +28,8 @@ namespace helpers {
void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, void nonMaxSuppression(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output); double overlapThreshold, double scoreThreshold, NDArray* output);
Nd4jLong nonMaxSuppressionV3(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output);
Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize, Nd4jLong nonMaxSuppressionGeneric(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scores, int maxSize,
double overlapThreshold, double scoreThreshold, NDArray* output); double overlapThreshold, double scoreThreshold, NDArray* output);

View File

@ -1446,35 +1446,63 @@ inline __device__ unsigned char nd4j_atomicMul<unsigned char>(unsigned char* add
return (uint8_t)old; return (uint8_t)old;
} }
template <> template <typename T>
inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) { static inline __device__ T internal_16bit_atomicMul(T* address, int16_t val) {
size_t shift = ((size_t)address & 2); size_t shift = ((size_t)address & 2);
int *base_address = (int *)((char*)address - shift); int *base_address = (int *)((char*)address - shift);
int old = val, assumed;
//printf("%u %x", *base_address);
do {
assumed = old; union I16PAIR {
old = atomicCAS(base_address, assumed, (old * val)); struct {
} while (assumed != old); T H;
T L;
} B;
int W;
return (int16_t)old; __host__ __device__
I16PAIR() {};
__host__ __device__
~I16PAIR() {};
};
I16PAIR pairNew, pairOld, pairAssumed;
pairOld.W = (int) val;
if (reinterpret_cast<int*>(address) == base_address) {
do {
pairNew.B.L = pairOld.B.L;
pairNew.B.H = pairOld.B.H * val;
pairAssumed.W = pairOld.W;
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
} while (pairAssumed.W != pairOld.W);
return (T) pairOld.B.H;
} else {
do {
pairNew.B.H = pairOld.B.H;
pairNew.B.L = pairOld.B.L * val;
pairAssumed.W = pairOld.W;
pairOld.W = atomicCAS(base_address, pairAssumed.W, pairNew.W);
} while (pairAssumed.W != pairOld.W);
return (T) pairOld.B.L;
}
}
template <>
inline __device__ int16_t nd4j_atomicMul<int16_t>(int16_t* address, int16_t val) {
return internal_16bit_atomicMul<int16_t>(address, val);
} }
template <> template <>
inline __device__ uint16_t nd4j_atomicMul<uint16_t>(uint16_t* address, uint16_t val) { inline __device__ uint16_t nd4j_atomicMul<uint16_t>(uint16_t* address, uint16_t val) {
size_t shift = ((size_t)address & 2); return internal_16bit_atomicMul<uint16_t>(address, val);
unsigned int *base_address = (unsigned int *)((char*)address - shift);
unsigned int old = val, assumed;
//printf("%u %x", *base_address);
do {
assumed = old;
old = atomicCAS(base_address, assumed, (old * val));
} while (assumed != old);
return (uint16_t)old;
} }
template <> template <>
@ -1547,105 +1575,27 @@ inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong
template <> template <>
inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) {
auto address_as_ull = (int*) address; return internal_16bit_atomicMul<bfloat16>(address, val);
long addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;
old.W = *address_as_ull;
do {
if (!misaligned) {
bfloat16 res = old.B.H * val;
fresh.B.H = res;
fresh.B.L = old.B.L;
} else {
bfloat16 res = old.B.L * val;
fresh.B.L = res;
fresh.B.H = old.B.H;
}
assumed.W = old.W;
old.W = atomicCAS(address_as_ull, assumed.W, fresh.W);
} while (assumed.W != old.W);
if (!misaligned) return old.B.H;
else return old.B.L;
} }
template <> template <>
inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) { inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) {
auto address_as_ull = (int*) address; return internal_16bit_atomicMul<float16>(address, val);
long addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;
old.W = *address_as_ull;
do {
if (!misaligned) {
bfloat16 res = old.B.H * val;
fresh.B.H = res;
fresh.B.L = old.B.L;
} else {
bfloat16 res = old.B.L * val;
fresh.B.L = res;
fresh.B.H = old.B.H;
}
assumed.W = old.W;
old.W = atomicCAS(address_as_ull, assumed.W, fresh.W);
} while (assumed.W != old.W);
if (!misaligned) return old.B.H;
else return old.B.L;
} }
template <> template <>
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) { inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
int* address_as_ull = return nd4j_atomicMul<float>(address, (float) 1.f / val);
(int*)address;
int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val ));
} while (assumed != old);
return __int_as_float(old);
} }
template <> template <>
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) { inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
int* address_as_ull = return nd4j_atomicMul<float16>(address, (float16) 1.f / val);
(int*)address;
int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
__float_as_int(assumed)));
} while (assumed != old);
return __int_as_float(old);
} }
template <> template <>
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = return nd4j_atomicMul<bfloat16>(address, (bfloat16) 1.f / val);
(int*)address;
int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
__float_as_int(assumed)));
} while (assumed != old);
return __int_as_float(old);
} }
} }
#endif #endif

View File

@ -0,0 +1,86 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <helpers/RandomLauncher.h>
#include <exceptions/cuda_exception.h>
using namespace nd4j;
class AtomicTests : public testing::Test {
public:
AtomicTests() {
//
}
};
template <typename T>
static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) {
auto buffer = reinterpret_cast<T*>(vbuffer);
auto result = reinterpret_cast<T*>(vresult);
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (auto e = tid; e < length; e += gridDim.x * blockDim.x) {
auto rem = e % 4;
auto i = (e - rem) / 4;
nd4j::math::atomics::nd4j_atomicMul<T>(&result[i], buffer[e]);
}
}
template <typename T>
static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) {
multiplyKernel<T><<<256, 256, 1024, *nd4j::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult);
auto err = cudaStreamSynchronize(*nd4j::LaunchContext::defaultContext()->getCudaStream());
if (err != 0)
nd4j::cuda_exception::build("multiply failed", err);
}
static void multiplyHost(NDArray &input, NDArray &output) {
BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
}
TEST_F(AtomicTests, test_multiply) {
std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16};
for (auto t:dtypes) {
nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
NDArray input('c', {4, 25}, t);
NDArray output('c', {input.lengthOf() / 4}, t);
NDArray exp = output.ulike();
input.assign(2);
output.assign(2);
exp.assign(32);
multiplyHost(input, output);
ASSERT_EQ(exp, output);
}
}

View File

@ -2330,6 +2330,76 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
delete results; delete results;
} }
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) {
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1,2});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput6");
// result->printShapeInfo("Ouput6 shape is");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) {
NDArray boxes = NDArrayFactory::create<bfloat16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<bfloat16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1,2});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput06");
// result->printShapeInfo("Ouput06 shape is");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) {
NDArray boxes = NDArrayFactory::create<float>('c', {3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2329f,
0.7271f, 0.1804f, 0.5056f, 0.8929f,
0.5461f, 0.9234f, 0.0856f, 0.7938f});
NDArray scales = NDArrayFactory::create<float>('c', {3}, {0.7717f, 0.9281f, 0.9846f}); //3, 0, 1, 2, 4, 5
NDArray maxSize = NDArrayFactory::create(0);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5f);
nd4j::ops::non_max_suppression_v3 op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput7");
// result->printShapeInfo("Ouput6 shape is");
ASSERT_TRUE(result->isEmpty());
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
@ -2720,22 +2790,22 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3}); NDArray x = NDArrayFactory::create<float>('c', {2,4,5,3});
NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{ NDArray exp = NDArrayFactory::create<float>('c', {2,4,5,3},{
1.0588236, 1.9607843, 3.019608, 4.0588236, 5.098039, 6.039216, 7.0588236, 8.039216, 9.058824, 1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, 7.0588236f, 8.039216f, 9.058824f,
10.058824, 10.980392, 12.078432, 13.058824, 13.921569, 15.09804, 16.058825, 17.058825, 18.117647, 10.058824f, 10.980392f, 12.078432f, 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f,
19.058825, 20., 21.137257, 22.058825, 22.941177, 23.882355, 25.058825, 26.078432, 26.901962, 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, 25.058825f, 26.078432f, 26.901962f,
28.058825, 29.019608, 29.92157, 31.058825, 31.960785, 32.941177, 34.058823, 35.09804, 35.960785, 28.058825f, 29.019608f, 29.92157f, 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f,
37.058823, 38.039215, 38.980392, 40.058823, 40.980392, 42.000004, 43.058826, 43.92157, 45.01961, 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, 43.058826f, 43.92157f, 45.01961f,
45., 47.058823, 48.03922, 45., 50., 51.058826, 45., 50., 54.078434, 45.f, 47.058823f, 48.03922f, 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f,
45., 50., 57.09804, 45., 50., 60.11765, 45., 50., 62.862747, 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, 45.f, 50.f, 62.862747f,
45., 50., 65.882355, 45., 50., 68.90196, 45., 50., 70., 45.f, 50.f, 65.882355f, 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f,
45., 50., 70., 45., 50., 70., 45., 50., 70., 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
45., 50., 70., 45., 50., 70., 45., 50., 70., 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
45., 50., 70., 45., 50., 70., 45., 50., 70., 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
45., 50., 70., 45., 50., 70., 45., 50., 70., 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
45., 50., 70., 45., 50., 70., 45., 50., 70., 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f,
45., 50., 70.}); 45.f, 50.f, 70.f});
NDArray min = NDArrayFactory::create<float>({20., 20., 20.}); NDArray min = NDArrayFactory::create<float>({20.f, 20.f, 20.f});
NDArray max = NDArrayFactory::create<float>({65., 70., 90.}); NDArray max = NDArrayFactory::create<float>({65.f, 70.f, 90.f});
x.linspace(1.); x.linspace(1.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.execute({&x, &min, &max}, {}, {});
@ -2756,36 +2826,36 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4}); NDArray x = NDArrayFactory::create<float>('c', {2, 3, 5, 4});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{ NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 5, 4},{
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-19.92157 , -18.980392 , -18.039217 , -16.941177 , -19.92157f, -18.980392f, -18.039217f, -16.941177f,
-16. , -15.058824 , -13.960785 , -13.0196085 , -16.f, -15.058824f, -13.960785f, -13.0196085f,
-11.92157 , -10.980392 , -10.039217 , -8.941177 , -11.92157f, -10.980392f, -10.039217f, -8.941177f,
-8.000001 , -7.0588236 , -5.960785 , -5.0196085 , -8.000001f, -7.0588236f, -5.960785f, -5.0196085f,
-3.9215698 , -2.9803925 , -2.039217 , -0.94117737, -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f,
0. , 0.94117737, 2.039215 , 2.9803925 , 0.f, 0.94117737f, 2.039215f, 2.9803925f,
4.07843 , 5.0196075 , 5.960783 , 7.0588226 , 4.07843f, 5.0196075f, 5.960783f, 7.0588226f,
8. , 8.941177 , 10.039215 , 10.980392 , 8.f, 8.941177f, 10.039215f, 10.980392f,
12.07843 , 13.019608 , 13.960783 , 15.058823 , 12.07843f, 13.019608f, 13.960783f, 15.058823f,
16. , 16.941177 , 18.039217 , 18.980392 , 16.f, 16.941177f, 18.039217f, 18.980392f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 , 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843 , 21.019608 , 21.960783 , 23.058823 20.07843f, 21.019608f, 21.960783f, 23.058823f
}); });
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17}); NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17});
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23}); NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23});
@ -2833,9 +2903,10 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
// 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f // 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f
// }); // });
NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {0.77700233, 0.596913, 0.72314, 0.23104, 0.50982356, NDArray exp = NDArrayFactory::create<float>('c', {3,5}, {
0.17930824, 0.50528157, 0.86846, 0.34995764, 0.50982356, 0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f,
0.08735529, 0.596913, 0.6574, 0.34995764, 0.15974471}); 0.17930824f, 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f,
0.08735529f, 0.596913f, 0.6574f, 0.34995764f, 0.15974471f});
NDArray min = NDArrayFactory::create<float>('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); NDArray min = NDArrayFactory::create<float>('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f});
NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); NDArray max = NDArrayFactory::create<float>('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f});
// x.linspace(-60.); // x.linspace(-60.);
@ -2856,45 +2927,74 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////
//TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
//
// NDArray x = NDArrayFactory::create<double>('c', {100}); NDArray x = NDArrayFactory::create<float>('c', {100});
// NDArray exp = NDArrayFactory::create<double>('c', {100}, { NDArray exp = NDArrayFactory::create<float>('c', {100}, {
// 0.f, 0.f, 0.f , 0.f , 0.06666667f, 0.06666667f , 0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f,
// 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.13333334 , 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f,
// 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.20000002 , 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f,
// 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002 , 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f,
// 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668 , 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f,
// 0.26666668, 0.33333334, 0.33333334, 0.33333334, 0.33333334, 0.33333334 , 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f,
// 0.33333334, 0.40000004, 0.40000004, 0.40000004, 0.40000004, 0.40000004 , 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f,
// 0.40000004, 0.40000004, 0.4666667 , 0.4666667 , 0.4666667 , 0.4666667 , 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f,
// 0.4666667 , 0.4666667 , 0.4666667 , 0.53333336, 0.53333336, 0.53333336 , 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f,
// 0.53333336, 0.53333336, 0.53333336, 0.6 , 0.6 , 0.6 , 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f,
// 0.6 , 0.6 , 0.6 , 0.6 , 0.6666667 , 0.6666667 , 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f,
// 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.73333335 , 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f,
// 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.8000001 , 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f,
// 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f,
// 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673 , 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f,
// 0.86666673, 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f,
// 0.9333334 , 1., 1., 1., 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f,
// }); 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f,
// NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f}); 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f,
// NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f}); 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f
// x.linspace(0., 0.01); });
// nd4j::ops::fake_quant_with_min_max_vars op; NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
// auto results = op.execute({&x, &min, &max}, {}, {}); NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
// x.linspace(0., 0.01);
// ASSERT_EQ(ND4J_STATUS_OK, results->status()); nd4j::ops::fake_quant_with_min_max_vars op;
// auto results = op.execute({&x, &min, &max}, {}, {});
// auto result = results->at(0);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer("Quantized7"); // result->printBuffer("Quantized7");
// exp.printBuffer("Expected 7"); // exp.printBuffer("Expected 7");
// ASSERT_TRUE(exp.isSameShapeStrict(result)); ASSERT_TRUE(exp.isSameShapeStrict(result));
// ASSERT_TRUE(exp.equalsTo(result)); ASSERT_TRUE(exp.equalsTo(result));
//
// delete results; delete results;
//} }
//////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
NDArray x = NDArrayFactory::create<float>('c', {10});
NDArray exp = NDArrayFactory::create<float>('c', {10}, {
0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f,
0.6f, 0.69803923f, 0.8000001f, 0.8980393f
});
NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
x.linspace(0., 0.1);
nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// x.printBuffer("SourInput8");
// result->printBuffer("Quantized8");
// exp.printBuffer("Expected 8");
ASSERT_TRUE(exp.isSameShapeStrict(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test1) { TEST_F(DeclarableOpsTests10, batchnorm_test1) {

View File

@ -1523,14 +1523,12 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8}); auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8});
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008}); auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
//x.printIndexedBuffer("Input");
nd4j::ops::logdet op; nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {}); auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));

View File

@ -1813,8 +1813,8 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { TEST_F(DeclarableOpsTests7, TestSegmentSum_1) {
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto x = NDArrayFactory::create<double>({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. });
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto idx = NDArrayFactory::create<int>({ 0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4});
auto exp = NDArrayFactory::create<double>({4.3, 17.5, 3., 13.2, 11.2}); auto exp = NDArrayFactory::create<double>({4.3, 17.5, 3., 13.2, 11.2});
nd4j::ops::segment_sum op; nd4j::ops::segment_sum op;
@ -2231,6 +2231,27 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_05) {
auto result = op.execute({&x, &idx}, {}, {}); auto result = op.execute({&x, &idx}, {}, {});
ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->status(), Status::OK());
auto res = result->at(0);
// res->printIndexedBuffer("Segment prod 05");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) {
auto x = NDArrayFactory::create<int>({1,2,3,4,5,6,7,8 });
// ----------------------------------------------------------------
auto idx = NDArrayFactory::create<int>({0,0,1,2,2,2,3,3});
auto exp = NDArrayFactory::create<int>({ 2, 3, 120, 56});
nd4j::ops::segment_prod op;
auto result = op.execute({&x, &idx}, {}, {});
ASSERT_EQ(result->status(), Status::OK());
auto res = result->at(0);
// res->printIndexedBuffer("Segment prod 05_1");
ASSERT_TRUE(exp.equalsTo(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result; delete result;
@ -2270,6 +2291,23 @@ TEST_F(DeclarableOpsTests7, TestSegmentProd_07) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentProd_08) {
auto x = NDArrayFactory::create<int>({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' });
// ----------------------------------------------------------------
auto idx = NDArrayFactory::create<int>({0,0,2,2,2,2,3,3,3,3});
auto exp = NDArrayFactory::create<int>({ 2, 1,360, 5040});
nd4j::ops::segment_prod op;
auto result = op.execute({&x, &idx}, {}, {});
ASSERT_EQ(result->status(), Status::OK());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) {
auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
@ -2341,6 +2379,22 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) {
auto x = NDArrayFactory::create<int>({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' });
// ----------------------------------------------------------------
auto idx = NDArrayFactory::create<int>({0,0,2,2,2,2,3,3,3,3});
auto exp = NDArrayFactory::create<int>({ 2, 1,360, 5040});
nd4j::ops::unsorted_segment_prod op;
auto result = op.execute({&x, &idx}, {}, {4});
ASSERT_EQ(result->status(), Status::OK());
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) {
@ -2401,6 +2455,41 @@ TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) {
auto x = NDArrayFactory::create<double>('c', {8, 15});
// ----------------------------------------------------------------
auto idx = NDArrayFactory::create<int>({3, 1, 2, 1, 2, 3, 2, 1});
auto exp = NDArrayFactory::create<double>('c', {4, 15}, {
1., 1., 1., 1., 1.,
1., 1., 1., 1., 1.,
1., 1., 1., 1., 1.,
78016., 85493., 93312., 101479., 110000.,
118881., 128128., 137747., 147744., 158125.,
168896., 180063., 191632., 203609., 216000.,
172081., 182528., 193347., 204544., 216125.,
228096., 240463., 253232., 266409., 280000.,
294011., 308448., 323317., 338624., 354375.,
76., 154., 234., 316., 400.,
486., 574., 664., 756., 850.,
946., 1044., 1144., 1246., 1350.});
x.linspace(1.);
nd4j::ops::unsorted_segment_prod op;
auto result = op.execute({&x, &idx}, {}, {4});
ASSERT_EQ(result->status(), Status::OK());
//result->at(0)->printIndexedBuffer("Output");
// result->at(0)->printShapeInfo("Out Shape");
//exp.printIndexedBuffer("Expect");
// exp.printShapeInfo("Exp Shape");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) {
auto x = NDArrayFactory::create<double>('c', {8}, { auto x = NDArrayFactory::create<double>('c', {8}, {

View File

@ -39,7 +39,7 @@ TEST_F(CudaWorkspaceTests, Basic_Tests_1) {
ctx.setWorkspace(&workspace); ctx.setWorkspace(&workspace);
auto array = NDArrayFactory::create<float>('c', {5, 5}, &ctx); auto array = NDArrayFactory::create<float>('c', {5, 5}, &ctx);
ASSERT_EQ(100, workspace.getCurrentOffset()); ASSERT_EQ(108, workspace.getCurrentOffset());
ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
array.e<int>(0); array.e<int>(0);
@ -55,6 +55,6 @@ TEST_F(CudaWorkspaceTests, Basic_Tests_2) {
ctx.setWorkspace(&workspace); ctx.setWorkspace(&workspace);
auto array = NDArrayFactory::create<float>('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx); auto array = NDArrayFactory::create<float>('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx);
ASSERT_EQ(100, workspace.getCurrentOffset()); ASSERT_EQ(108, workspace.getCurrentOffset());
ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); ASSERT_EQ(0, workspace.getCurrentSecondaryOffset());
} }

View File

@ -86,6 +86,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.image.CropAndResize.class, org.nd4j.linalg.api.ops.impl.image.CropAndResize.class,
org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class, org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class,
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class,
org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class,
org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class, org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class,
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class, org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class, org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,

View File

@ -53,7 +53,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
@Override @Override
public String[] tensorflowNames() { public String[] tensorflowNames() {
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2", "NonMaxSuppressionV3"}; return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
} }
@Override @Override

View File

@ -0,0 +1,79 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import java.util.Collections;
import java.util.List;
/**
* Non max suppression
*
* @author raver119@gmail.com
*/
public class NonMaxSuppressionV3 extends DynamicCustomOp {
public NonMaxSuppressionV3() {}
public NonMaxSuppressionV3(SameDiff sameDiff, @NonNull SDVariable boxes, @NonNull SDVariable scores, @NonNull SDVariable maxOutSize,
@NonNull SDVariable iouThreshold, @NonNull SDVariable scoreThreshold) {
super(null, sameDiff, new SDVariable[]{boxes, scores, maxOutSize, iouThreshold, scoreThreshold}, false);
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx name found for shape " + opName());
}
@Override
public String tensorflowName() {
return "NonMaxSuppressionV3";
}
@Override
public String[] tensorflowNames() {
return new String[]{"NonMaxSuppressionV3","NonMaxSuppressionV4"};
}
@Override
public String opName() {
return "non_max_suppression_v3";
}
@Override
public Op.Type opType() {
return Op.Type.CUSTOM;
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
return Collections.singletonList(sameDiff.zerosLike(arg()));
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
//Always 1D integer tensor (indices)
return Collections.singletonList(DataType.INT);
}
}

View File

@ -20827,7 +20827,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif // #endif
/** /**
* image.non_max_suppression op. * image.non_max_suppression ops.
* input: * input:
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
* 1 - scales - 1D-tensor with shape (num_boxes) by float type * 1 - scales - 1D-tensor with shape (num_boxes) by float type
@ -20859,6 +20859,23 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
// #if NOT_EXCLUDED(OP_image_non_max_suppression_v3)
@Namespace("nd4j::ops") public static class non_max_suppression_v3 extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public non_max_suppression_v3(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public non_max_suppression_v3(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public non_max_suppression_v3 position(long position) {
return (non_max_suppression_v3)super.position(position);
}
public non_max_suppression_v3() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/* /*
* image.non_max_suppression_overlaps op. * image.non_max_suppression_overlaps op.