Shugeo release fix2 (#70)
* Corrected input checking and tests for bitcast op. * Fixed an issue with non_max_suppression form generation and processing with score threshold given. * Fixed bilinear resize kernel and tests. * push for Serhii Signed-off-by: raver119 <raver119@gmail.com> * Added test for nearest_neighbor resize with int input. * Added data type check for input/output match. * Eliminate error in macros. * Improved output message for type checking. * Fixed input/output types for op. * Eliminated waste logging. * Refactored resize_bilinear helper for multithreading for cpu platform. * Cosmetic changes only. * Fixed error for string substitution. * Skip test for cbow_batch with cuda. * fix for resizeNearestNeighbor output dtype Signed-off-by: raver119 <raver119@gmail.com> * Refactored non_max_suppression helper. * Refactored shape generation and input handling. * Added additional test.master
parent
289d9dc141
commit
4187190609
|
@ -30,6 +30,17 @@ namespace nd4j {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
// when empty - nothing to do
|
// when empty - nothing to do
|
||||||
|
DataType newType = DataTypeUtils::fromInt(INT_ARG(0));
|
||||||
|
DataType oldType = input->dataType();
|
||||||
|
// correct output shape to conform with output data type
|
||||||
|
auto inputSize = DataTypeUtils::sizeOf(oldType);
|
||||||
|
auto outputSize = DataTypeUtils::sizeOf(newType);
|
||||||
|
auto lastSize = outputSize / inputSize;
|
||||||
|
if (inputSize < outputSize) {
|
||||||
|
REQUIRE_TRUE(input->sizeAt(-1) == lastSize, 0,
|
||||||
|
"BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize,
|
||||||
|
outputSize, lastSize, input->sizeAt(-1));
|
||||||
|
}
|
||||||
if(input->isEmpty()){
|
if(input->isEmpty()){
|
||||||
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
|
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -70,7 +81,7 @@ namespace nd4j {
|
||||||
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
|
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
|
||||||
return SHAPELIST(outputShape);
|
return SHAPELIST(outputShape);
|
||||||
}
|
}
|
||||||
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
|
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
|
||||||
std::vector<Nd4jLong> shapeOf(inputRank - 1);
|
std::vector<Nd4jLong> shapeOf(inputRank - 1);
|
||||||
|
|
||||||
for (auto i = 0; i < shapeOf.size(); ++i) {
|
for (auto i = 0; i < shapeOf.size(); ++i) {
|
||||||
|
|
|
@ -37,6 +37,22 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||||
|
|
||||||
|
double overlayThreshold = 0.5;
|
||||||
|
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
|
||||||
|
|
||||||
|
if (block.width() > 3) {
|
||||||
|
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
|
||||||
|
}
|
||||||
|
else if (block.getTArguments()->size() > 0) {
|
||||||
|
overlayThreshold = T_ARG(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (block.width() > 4) {
|
||||||
|
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
|
||||||
|
}
|
||||||
|
else if (block.getTArguments()->size() > 1) {
|
||||||
|
scoreThreshold = T_ARG(1);
|
||||||
|
}
|
||||||
if (boxes->isEmpty() || scales->isEmpty())
|
if (boxes->isEmpty() || scales->isEmpty())
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
||||||
|
@ -44,15 +60,6 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
|
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
|
||||||
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
|
||||||
|
|
||||||
if (scales->lengthOf() < maxOutputSize)
|
|
||||||
maxOutputSize = scales->lengthOf();
|
|
||||||
double overlayThreshold = 0.5;
|
|
||||||
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
|
|
||||||
if (block.getTArguments()->size() > 0)
|
|
||||||
overlayThreshold = T_ARG(0);
|
|
||||||
if (block.getTArguments()->size() > 1)
|
|
||||||
scoreThreshold = T_ARG(1);
|
|
||||||
|
|
||||||
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output);
|
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -70,10 +77,19 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
|
||||||
|
|
||||||
|
auto actualIndicesCount = shape::sizeAt(in, 0);
|
||||||
Nd4jLong boxSize = shape::sizeAt(in, 0);
|
if (block.getTArguments()->size() > 1 || block.width() > 4) {
|
||||||
if (boxSize < maxOutputSize)
|
auto scoreThreshold = block.getTArguments()->size() > 1?T_ARG(1):INPUT_VARIABLE(4)->e<double>(0);
|
||||||
maxOutputSize = boxSize;
|
auto scales = INPUT_VARIABLE(1);
|
||||||
|
scales->syncToHost();
|
||||||
|
for (auto e = 0; e < scales->lengthOf(); e++) {
|
||||||
|
if (scales->e<float>(e) < (float)scoreThreshold) {
|
||||||
|
actualIndicesCount--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (actualIndicesCount < maxOutputSize)
|
||||||
|
maxOutputSize = actualIndicesCount;
|
||||||
|
|
||||||
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);
|
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,9 @@ namespace nd4j {
|
||||||
auto inRank = image->rankOf();
|
auto inRank = image->rankOf();
|
||||||
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
||||||
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
||||||
|
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
|
||||||
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
|
|
||||||
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||||
|
|
||||||
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
||||||
|
@ -105,8 +107,8 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
DECLARE_TYPES(resize_nearest_neighbor) {
|
DECLARE_TYPES(resize_nearest_neighbor) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
|
|
@ -41,11 +41,11 @@ namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
struct BilinearInterpolationData {
|
struct BilinearInterpolationData {
|
||||||
Nd4jLong bottomIndex; // Lower source index used in the interpolation
|
Nd4jLong _bottomIndex; // Lower source index used in the interpolation
|
||||||
Nd4jLong topIndex; // Upper source index used in the interpolation
|
Nd4jLong _topIndex; // Upper source index used in the interpolation
|
||||||
// 1-D linear iterpolation scale (see:
|
// 1D linear iterpolation scale (see:
|
||||||
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
||||||
double interpolarValue;
|
double _interpolarValue;
|
||||||
};
|
};
|
||||||
// calculateResizeScale determines the float scaling factor.
|
// calculateResizeScale determines the float scaling factor.
|
||||||
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
||||||
|
@ -137,16 +137,16 @@ namespace helpers {
|
||||||
Nd4jLong inSize,
|
Nd4jLong inSize,
|
||||||
double scale,
|
double scale,
|
||||||
BilinearInterpolationData *interpolationData) {
|
BilinearInterpolationData *interpolationData) {
|
||||||
interpolationData[outSize].bottomIndex = 0;
|
interpolationData[outSize]._bottomIndex = 0;
|
||||||
interpolationData[outSize].topIndex = 0;
|
interpolationData[outSize]._topIndex = 0;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto k = start; k < stop; k++) {
|
for (auto k = start; k < stop; k++) {
|
||||||
auto i = (outSize - k - 1);
|
auto i = (outSize - k - 1);
|
||||||
double in = i * scale;
|
double in = i * scale;
|
||||||
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
|
interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
|
||||||
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
|
interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
|
||||||
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
|
interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, outSize);
|
samediff::Threads::parallel_for(func, 0, outSize);
|
||||||
|
@ -159,8 +159,8 @@ namespace helpers {
|
||||||
static void
|
static void
|
||||||
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const &xs,
|
std::vector<BilinearInterpolationData> const& xs,
|
||||||
std::vector<BilinearInterpolationData> const &ys,
|
std::vector<BilinearInterpolationData> const& ys,
|
||||||
NDArray *output);
|
NDArray *output);
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -175,10 +175,10 @@ namespace helpers {
|
||||||
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
||||||
Nd4jLong outRowSize = outWidth * channels;
|
Nd4jLong outRowSize = outWidth * channels;
|
||||||
|
|
||||||
T const *pInput = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
||||||
BilinearInterpolationData const *xs_ = xs.data();
|
BilinearInterpolationData const* xsPtr = xs.data();
|
||||||
|
|
||||||
T* pOutput = output->dataBuffer()->primaryAsT<T>();
|
T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
|
||||||
auto computeBilinear = [](double topLeft, double topRight,
|
auto computeBilinear = [](double topLeft, double topRight,
|
||||||
double bottomLeft, double bottomRight,
|
double bottomLeft, double bottomRight,
|
||||||
double xVal, double yVal) {
|
double xVal, double yVal) {
|
||||||
|
@ -187,32 +187,31 @@ namespace helpers {
|
||||||
return top + (bottom - top) * yVal;
|
return top + (bottom - top) * yVal;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto b = start; b < stop; ++b) {
|
for (auto batch = start; batch < stop; ++batch) {
|
||||||
for (auto y = 0; y < outHeight; ++y) {
|
auto pInput = pInputBuf + batch * inBatchNumValues;
|
||||||
const T *ys_input_lower_ptr = pInput + ys[y].bottomIndex * inRowSize;
|
for (auto y = 0; y < outHeight; ++y) {
|
||||||
const T *ys_input_upper_ptr = pInput + ys[y].topIndex * inRowSize;
|
auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize;
|
||||||
double yVal = ys[y].interpolarValue;
|
const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize;
|
||||||
for (auto x = 0; x < outWidth; ++x) {
|
const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize;
|
||||||
auto xsBottom = xs_[x].bottomIndex;
|
double yVal = ys[y]._interpolarValue;
|
||||||
auto xsTop = xs_[x].topIndex;
|
for (auto x = 0; x < outWidth; ++x) {
|
||||||
auto xVal = xs_[x].interpolarValue;
|
auto xsBottom = xsPtr[x]._bottomIndex;
|
||||||
for (auto c = 0; c < channels; ++c) {
|
auto xsTop = xsPtr[x]._topIndex;
|
||||||
double topLeft(ys_input_lower_ptr[xsBottom + c]);
|
auto xVal = xsPtr[x]._interpolarValue;
|
||||||
double topRight(ys_input_lower_ptr[xsTop + c]);
|
for (auto c = 0; c < channels; ++c) {
|
||||||
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
double topLeft(ysInputLowerPtr[xsBottom + c]);
|
||||||
double bottomRight(ys_input_upper_ptr[xsTop + c]);
|
double topRight(ysInputLowerPtr[xsTop + c]);
|
||||||
pOutput[x * channels + c] =
|
double bottomLeft(ysInputUpperPtr[xsBottom + c]);
|
||||||
computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
|
double bottomRight(ysInputUpperPtr[xsTop + c]);
|
||||||
xVal, yVal);
|
pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
|
||||||
|
xVal, yVal);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pOutput += outRowSize;
|
|
||||||
}
|
}
|
||||||
pInput += inBatchNumValues;
|
};
|
||||||
}
|
samediff::Threads::parallel_tad(func, 0, batchSize);
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, batchSize);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -257,8 +256,8 @@ namespace helpers {
|
||||||
// Scale x interpolation weights to avoid a multiplication during iteration.
|
// Scale x interpolation weights to avoid a multiplication during iteration.
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto i = start; i < stop; i += increment) {
|
for (auto i = start; i < stop; i += increment) {
|
||||||
xs[i].bottomIndex *= channels;
|
xs[i]._bottomIndex *= channels;
|
||||||
xs[i].topIndex *= channels;
|
xs[i]._topIndex *= channels;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, xsSize);
|
samediff::Threads::parallel_for(func, 0, xsSize);
|
||||||
|
|
|
@ -33,23 +33,27 @@ namespace helpers {
|
||||||
double scoreThreshold, NDArray* output) {
|
double scoreThreshold, NDArray* output) {
|
||||||
std::vector<int> indices(scales->lengthOf());
|
std::vector<int> indices(scales->lengthOf());
|
||||||
std::iota(indices.begin(), indices.end(), 0);
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
auto actualIndicesCount = indices.size();
|
||||||
for (auto e = 0; e < scales->lengthOf(); e++) {
|
for (auto e = 0; e < scales->lengthOf(); e++) {
|
||||||
if (scales->e<double>(e) < scoreThreshold) indices[e] = -1;
|
if (scales->e<float>(e) < (float)scoreThreshold) {
|
||||||
|
indices[e] = -1;
|
||||||
|
actualIndicesCount--;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return i >= 0 && j >=0?scales->e<T>(i) > scales->e<T>(j):(i > j);});
|
||||||
|
|
||||||
// std::vector<int> selected(output->lengthOf());
|
// std::vector<int> selected(output->lengthOf());
|
||||||
std::vector<int> selectedIndices(output->lengthOf(), 0);
|
std::vector<int> selectedIndices(output->lengthOf(), 0);
|
||||||
auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool {
|
auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool {
|
||||||
if (previousIndex < 0 || nextIndex < 0) return true;
|
if (previousIndex < 0 || nextIndex < 0) return true;
|
||||||
T minYPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2));
|
T minYPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
|
||||||
T minXPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3));
|
T minXPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
|
||||||
T maxYPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2));
|
T maxYPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
|
||||||
T maxXPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3));
|
T maxXPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
|
||||||
T minYNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2));
|
T minYNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
|
||||||
T minXNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3));
|
T minXNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
|
||||||
T maxYNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2));
|
T maxYNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
|
||||||
T maxXNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3));
|
T maxXNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
|
||||||
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
||||||
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
||||||
|
|
||||||
|
@ -67,7 +71,7 @@ namespace helpers {
|
||||||
|
|
||||||
};
|
};
|
||||||
// int numSelected = 0;
|
// int numSelected = 0;
|
||||||
int numBoxes = boxes->sizeAt(0);
|
int numBoxes = actualIndicesCount; //boxes->sizeAt(0);
|
||||||
int numSelected = 0;
|
int numSelected = 0;
|
||||||
|
|
||||||
for (int i = 0; i < numBoxes; ++i) {
|
for (int i = 0; i < numBoxes; ++i) {
|
||||||
|
|
|
@ -77,22 +77,19 @@ namespace helpers {
|
||||||
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
||||||
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
||||||
|
|
||||||
if (blockIdx.x < batchSize) { // blockIdx.x as batch index
|
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
|
||||||
auto pX = input + blockIdx.x * inBatchNumValues;
|
auto pX = input + batch * inBatchNumValues;
|
||||||
|
|
||||||
auto channelStart = blockIdx.z * blockDim.z + threadIdx.z;
|
|
||||||
auto step = blockDim.z * gridDim.z;
|
|
||||||
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
||||||
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
||||||
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
||||||
double yVal = ys_[y].interpolarValue;
|
double yVal = ys_[y].interpolarValue;
|
||||||
auto pZ = outputYptr + y * outRowSize;
|
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
|
||||||
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
||||||
auto xsBottom = xs_[x].bottomIndex;
|
auto xsBottom = xs_[x].bottomIndex;
|
||||||
auto xsTop = xs_[x].topIndex;
|
auto xsTop = xs_[x].topIndex;
|
||||||
auto xVal = xs_[x].interpolarValue;
|
auto xVal = xs_[x].interpolarValue;
|
||||||
// process interpolation for all channels
|
// process interpolation for all channels
|
||||||
for (int c = channelStart; c < channels; c += step) {
|
for (int c = threadIdx.z; c < channels; c += blockDim.z) {
|
||||||
double topLeft(ys_input_lower_ptr[xsBottom + c]);
|
double topLeft(ys_input_lower_ptr[xsBottom + c]);
|
||||||
double topRight(ys_input_lower_ptr[xsTop + c]);
|
double topRight(ys_input_lower_ptr[xsTop + c]);
|
||||||
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
||||||
|
@ -120,9 +117,15 @@ namespace helpers {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
||||||
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
|
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
|
||||||
|
dim3 batchSizeBlock(batchSize, 1, 1);
|
||||||
resizeImageKernel<T><<<batchSize, outHeight, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
|
dim3 pictureBlock(outHeight, outWidth, channels);
|
||||||
|
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
|
||||||
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
|
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
|
||||||
|
|
||||||
|
auto err = cudaStreamSynchronize(*stream);
|
||||||
|
if (err != 0) {
|
||||||
|
throw cuda_exception::build("helpers::resizeImage_: Cannot synchronize kernel execution", err);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -176,7 +179,6 @@ namespace helpers {
|
||||||
NDArray::prepareSpecialUse({output}, {images});
|
NDArray::prepareSpecialUse({output}, {images});
|
||||||
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
||||||
NDArray::registerSpecialUse({output}, {images});
|
NDArray::registerSpecialUse({output}, {images});
|
||||||
|
|
||||||
err = cudaFree(xs_);
|
err = cudaFree(xs_);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
||||||
|
|
|
@ -1530,7 +1530,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {0.7788f, 0.8012f, 0.7244f,
|
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {
|
||||||
|
0.7788f, 0.8012f, 0.7244f,
|
||||||
0.2309f, 0.7271f, 0.1804f,
|
0.2309f, 0.7271f, 0.1804f,
|
||||||
0.5056f, 0.8925f, 0.5461f,
|
0.5056f, 0.8925f, 0.5461f,
|
||||||
0.9234f, 0.0856f, 0.7938f,
|
0.9234f, 0.0856f, 0.7938f,
|
||||||
|
@ -1581,40 +1582,89 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
0.4739f, 0.7014f, 0.4473f,
|
0.4739f, 0.7014f, 0.4473f,
|
||||||
0.5171f, 0.1744f, 0.3487f});
|
0.5171f, 0.1744f, 0.3487f});
|
||||||
|
|
||||||
NDArray expected = NDArrayFactory::create<float>('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
|
NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 3}, {
|
||||||
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
|
0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f,
|
||||||
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
|
0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f,
|
||||||
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
|
0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f,
|
||||||
5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2,
|
0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f,
|
||||||
9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4,
|
0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f,
|
||||||
11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8,
|
0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f,
|
||||||
7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4,
|
0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 0.46649635f,
|
||||||
10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16.,
|
0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f,
|
||||||
13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8,
|
0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f,
|
||||||
8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6,
|
0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f,
|
||||||
11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
|
0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f,
|
||||||
15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2,
|
0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f,
|
||||||
16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8,
|
0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f,
|
||||||
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
|
0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f,
|
||||||
16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6,
|
0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f,
|
||||||
18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16.,
|
0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f,
|
||||||
14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6,
|
0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f,
|
||||||
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
|
0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f,
|
||||||
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
|
0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f,
|
||||||
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
|
0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f,
|
||||||
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
|
0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f,
|
||||||
20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24.,
|
0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f,
|
||||||
21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
|
0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f,
|
||||||
15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,
|
0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f,
|
||||||
19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24.,
|
0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f,
|
||||||
21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16.,
|
0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f,
|
||||||
14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,
|
0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f,
|
||||||
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
|
0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f,
|
||||||
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
|
0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f,
|
||||||
13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,
|
0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f,
|
||||||
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
|
0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f,
|
||||||
20.2,21.2, 22.2, 23.2,
|
0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f,
|
||||||
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.});
|
0.47803456f, 0.3974851f, 0.7738359f, 0.4686691f, 0.27816284f, 0.8476581f,
|
||||||
|
0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f,
|
||||||
|
0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f,
|
||||||
|
0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f,
|
||||||
|
0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f,
|
||||||
|
0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f,
|
||||||
|
0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f,
|
||||||
|
0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f,
|
||||||
|
0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f,
|
||||||
|
0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f,
|
||||||
|
0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f,
|
||||||
|
0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f,
|
||||||
|
0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f,
|
||||||
|
0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f,
|
||||||
|
0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f,
|
||||||
|
0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f,
|
||||||
|
0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f,
|
||||||
|
0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f,
|
||||||
|
0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f,
|
||||||
|
0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f,
|
||||||
|
0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f,
|
||||||
|
0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f,
|
||||||
|
0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f,
|
||||||
|
0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f,
|
||||||
|
0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f,
|
||||||
|
0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f,
|
||||||
|
0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f,
|
||||||
|
0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f,
|
||||||
|
0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f,
|
||||||
|
0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f,
|
||||||
|
0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f,
|
||||||
|
0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f,
|
||||||
|
0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f,
|
||||||
|
0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f,
|
||||||
|
0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f,
|
||||||
|
0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f,
|
||||||
|
0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f,
|
||||||
|
0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f,
|
||||||
|
0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f,
|
||||||
|
0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f,
|
||||||
|
0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f,
|
||||||
|
0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f,
|
||||||
|
0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f,
|
||||||
|
0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f,
|
||||||
|
0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f,
|
||||||
|
0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f,
|
||||||
|
0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f,
|
||||||
|
0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f,
|
||||||
|
0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f
|
||||||
|
});
|
||||||
//input.linspace(1);
|
//input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
|
@ -1624,12 +1674,12 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
|
|
||||||
NDArray* result = results->at(0);
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
result->printIndexedBuffer("Resized to 9x9");
|
// result->printBuffer("Resized to 9x9");
|
||||||
//expected.printIndexedBuffer("Expect for 10x10");
|
// expected.printBuffer("Expect for 9x9");
|
||||||
result->printShapeInfo("Output shape");
|
// result->printShapeInfo("Output shape");
|
||||||
// expected.printShapeInfo("Expect shape");
|
// expected.printShapeInfo("Expect shape");
|
||||||
// ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
// ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2015,6 +2065,53 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printIndexedBuffer("Resized to 4x5");
|
||||||
|
// expected.printIndexedBuffer("Expect for 4x5");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {2, 3, 4});
|
NDArray input = NDArrayFactory::create<double>('c', {2, 3, 4});
|
||||||
|
@ -2166,6 +2263,73 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
|
||||||
|
|
||||||
|
NDArray boxes = NDArrayFactory::create<float>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||||
|
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||||
|
0.8223f, 0.2246f, 0.0049f, 0.6465f});
|
||||||
|
NDArray scales = NDArrayFactory::create<float>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
nd4j::ops::non_max_suppression op;
|
||||||
|
auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
// result->printBuffer("NonMaxSuppression OUtput3");
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
|
||||||
|
|
||||||
|
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||||
|
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||||
|
0.8223f, 0.2246f, 0.0049f, 0.6465f});
|
||||||
|
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
NDArray maxSize = NDArrayFactory::create(2);
|
||||||
|
NDArray threshold = NDArrayFactory::create(0.5f);
|
||||||
|
NDArray scoreThreshold = NDArrayFactory::create(0.5);
|
||||||
|
nd4j::ops::non_max_suppression op;
|
||||||
|
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
// result->printBuffer("NonMaxSuppression OUtput4");
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
|
||||||
|
|
||||||
|
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
|
||||||
|
0.7412f, 0.7607f, 0.1543f, 0.5479f,
|
||||||
|
0.8223f, 0.2246f, 0.0049f, 0.6465f});
|
||||||
|
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1, 2});
|
||||||
|
NDArray maxSize = NDArrayFactory::create(2);
|
||||||
|
NDArray threshold = NDArrayFactory::create(0.5f);
|
||||||
|
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
|
||||||
|
nd4j::ops::non_max_suppression op;
|
||||||
|
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
// result->printBuffer("NonMaxSuppression OUtput4");
|
||||||
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
|
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
|
||||||
|
|
||||||
|
@ -2692,6 +2856,46 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
//TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
|
||||||
|
//
|
||||||
|
// NDArray x = NDArrayFactory::create<double>('c', {100});
|
||||||
|
// NDArray exp = NDArrayFactory::create<double>('c', {100}, {
|
||||||
|
// 0.f, 0.f, 0.f , 0.f , 0.06666667f, 0.06666667f ,
|
||||||
|
// 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.13333334 ,
|
||||||
|
// 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.20000002 ,
|
||||||
|
// 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002 ,
|
||||||
|
// 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668 ,
|
||||||
|
// 0.26666668, 0.33333334, 0.33333334, 0.33333334, 0.33333334, 0.33333334 ,
|
||||||
|
// 0.33333334, 0.40000004, 0.40000004, 0.40000004, 0.40000004, 0.40000004 ,
|
||||||
|
// 0.40000004, 0.40000004, 0.4666667 , 0.4666667 , 0.4666667 , 0.4666667 ,
|
||||||
|
// 0.4666667 , 0.4666667 , 0.4666667 , 0.53333336, 0.53333336, 0.53333336 ,
|
||||||
|
// 0.53333336, 0.53333336, 0.53333336, 0.6 , 0.6 , 0.6 ,
|
||||||
|
// 0.6 , 0.6 , 0.6 , 0.6 , 0.6666667 , 0.6666667 ,
|
||||||
|
// 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.73333335 ,
|
||||||
|
// 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.8000001 ,
|
||||||
|
// 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 ,
|
||||||
|
// 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673 ,
|
||||||
|
// 0.86666673, 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 ,
|
||||||
|
// 0.9333334 , 1., 1., 1.,
|
||||||
|
// });
|
||||||
|
// NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
|
||||||
|
// NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||||
|
// x.linspace(0., 0.01);
|
||||||
|
// nd4j::ops::fake_quant_with_min_max_vars op;
|
||||||
|
// auto results = op.execute({&x, &min, &max}, {}, {});
|
||||||
|
//
|
||||||
|
// ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
//
|
||||||
|
// auto result = results->at(0);
|
||||||
|
// result->printBuffer("Quantized7");
|
||||||
|
// exp.printBuffer("Expected 7");
|
||||||
|
// ASSERT_TRUE(exp.isSameShapeStrict(result));
|
||||||
|
// ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
//
|
||||||
|
// delete results;
|
||||||
|
//}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
||||||
|
|
||||||
|
|
|
@ -254,6 +254,34 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 4});
|
||||||
|
|
||||||
|
x.linspace(1.);
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
try {
|
||||||
|
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_NE(Status::OK(), result->status());
|
||||||
|
delete result;
|
||||||
|
} catch (std::exception& e) {
|
||||||
|
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 4});
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {1234567890LL, 2468013579LL});
|
||||||
|
x.linspace(1.);
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
try {
|
||||||
|
auto result = op.execute({&x}, {&e}, {}, {nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_NE(Status::OK(), result);
|
||||||
|
} catch(std::exception& e) {
|
||||||
|
nd4j_printf("Error `%s' should be here. It's OK.\n",e.what());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
||||||
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
||||||
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
||||||
|
|
|
@ -426,6 +426,10 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
TEST_F(NlpTests, test_cbow_hs_batch_1) {
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
return ;
|
||||||
|
#endif
|
||||||
|
|
||||||
auto target = NDArrayFactory::create<int>(0);
|
auto target = NDArrayFactory::create<int>(0);
|
||||||
auto ngStarter = NDArrayFactory::empty<int>();
|
auto ngStarter = NDArrayFactory::empty<int>();
|
||||||
auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102});
|
auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102});
|
||||||
|
|
|
@ -62,8 +62,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
|
||||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
||||||
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||||
if(inputDataTypes.get(0).isFPType())
|
|
||||||
return Collections.singletonList(inputDataTypes.get(0));
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
return Collections.singletonList(Nd4j.defaultFloatingPointType());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8005,9 +8005,12 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
|
||||||
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
||||||
|
@ -8024,6 +8027,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@ -8043,6 +8049,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
||||||
|
@ -8354,6 +8363,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
|
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
|
||||||
|
@ -8778,7 +8791,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -9110,6 +9123,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,9 @@ import java.nio.*;
|
||||||
import org.bytedeco.javacpp.*;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.annotation.*;
|
import org.bytedeco.javacpp.annotation.*;
|
||||||
|
|
||||||
|
import static org.bytedeco.openblas.global.openblas_nolapack.*;
|
||||||
|
import static org.bytedeco.openblas.global.openblas.*;
|
||||||
|
|
||||||
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
|
|
||||||
|
@ -8005,9 +8008,12 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
|
||||||
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
||||||
|
@ -8024,6 +8030,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@ -8043,6 +8052,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
||||||
|
@ -8354,6 +8366,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
|
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
|
||||||
|
@ -8778,7 +8794,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -9110,6 +9126,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue