Shugeo release fix2 (#70)

* Corrected input checking and tests for bitcast op.

* Fixed an issue with non_max_suppression form generation and processing with score threshold given.

* Fixed bilinear resize kernel and tests.

* push for Serhii

Signed-off-by: raver119 <raver119@gmail.com>

* Added test for nearest_neighbor resize with int input.

* Added data type check for input/output match.

* Eliminate error in macros.

* Improved output message for type checking.

* Fixed input/output types for op.

* Eliminated waste logging.

* Refactored resize_bilinear helper for multithreading for cpu platform.

* Cosmetic changes only.

* Fixed error for string substitution.

* Skip test for cbow_batch with cuda.

* fix for resizeNearestNeighbor output dtype

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression helper.

* Refactored shape generation and input handling.

* Added additional test.
master
shugeo 2019-11-22 21:42:44 +02:00 committed by raver119
parent 289d9dc141
commit 4187190609
13 changed files with 432 additions and 127 deletions

View File

@ -30,6 +30,17 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
// when empty - nothing to do // when empty - nothing to do
DataType newType = DataTypeUtils::fromInt(INT_ARG(0));
DataType oldType = input->dataType();
// correct output shape to conform with output data type
auto inputSize = DataTypeUtils::sizeOf(oldType);
auto outputSize = DataTypeUtils::sizeOf(newType);
auto lastSize = outputSize / inputSize;
if (inputSize < outputSize) {
REQUIRE_TRUE(input->sizeAt(-1) == lastSize, 0,
"BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize,
outputSize, lastSize, input->sizeAt(-1));
}
if(input->isEmpty()){ if(input->isEmpty()){
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
return Status::OK(); return Status::OK();
@ -70,7 +81,7 @@ namespace nd4j {
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf); auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape); return SHAPELIST(outputShape);
} }
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1)); REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
std::vector<Nd4jLong> shapeOf(inputRank - 1); std::vector<Nd4jLong> shapeOf(inputRank - 1);
for (auto i = 0; i < shapeOf.size(); ++i) { for (auto i = 0; i < shapeOf.size(); ++i) {

View File

@ -37,6 +37,22 @@ namespace nd4j {
else else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.width() > 3) {
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
}
else if (block.getTArguments()->size() > 0) {
overlayThreshold = T_ARG(0);
}
if (block.width() > 4) {
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
}
else if (block.getTArguments()->size() > 1) {
scoreThreshold = T_ARG(1);
}
if (boxes->isEmpty() || scales->isEmpty()) if (boxes->isEmpty() || scales->isEmpty())
return Status::OK(); return Status::OK();
@ -44,15 +60,6 @@ namespace nd4j {
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
if (scales->lengthOf() < maxOutputSize)
maxOutputSize = scales->lengthOf();
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.getTArguments()->size() > 0)
overlayThreshold = T_ARG(0);
if (block.getTArguments()->size() > 1)
scoreThreshold = T_ARG(1);
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output); helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output);
return Status::OK(); return Status::OK();
} }
@ -70,10 +77,19 @@ namespace nd4j {
else else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
auto actualIndicesCount = shape::sizeAt(in, 0);
Nd4jLong boxSize = shape::sizeAt(in, 0); if (block.getTArguments()->size() > 1 || block.width() > 4) {
if (boxSize < maxOutputSize) auto scoreThreshold = block.getTArguments()->size() > 1?T_ARG(1):INPUT_VARIABLE(4)->e<double>(0);
maxOutputSize = boxSize; auto scales = INPUT_VARIABLE(1);
scales->syncToHost();
for (auto e = 0; e < scales->lengthOf(); e++) {
if (scales->e<float>(e) < (float)scoreThreshold) {
actualIndicesCount--;
}
}
}
if (actualIndicesCount < maxOutputSize)
maxOutputSize = actualIndicesCount;
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32); outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);

View File

@ -54,7 +54,9 @@ namespace nd4j {
auto inRank = image->rankOf(); auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
@ -105,8 +107,8 @@ namespace nd4j {
} }
DECLARE_TYPES(resize_nearest_neighbor) { DECLARE_TYPES(resize_nearest_neighbor) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
} }
} }

View File

@ -26,7 +26,6 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);

View File

@ -41,11 +41,11 @@ namespace ops {
namespace helpers { namespace helpers {
struct BilinearInterpolationData { struct BilinearInterpolationData {
Nd4jLong bottomIndex; // Lower source index used in the interpolation Nd4jLong _bottomIndex; // Lower source index used in the interpolation
Nd4jLong topIndex; // Upper source index used in the interpolation Nd4jLong _topIndex; // Upper source index used in the interpolation
// 1-D linear iterpolation scale (see: // 1D linear iterpolation scale (see:
// https://en.wikipedia.org/wiki/Bilinear_interpolation) // https://en.wikipedia.org/wiki/Bilinear_interpolation)
double interpolarValue; double _interpolarValue;
}; };
// calculateResizeScale determines the float scaling factor. // calculateResizeScale determines the float scaling factor.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
@ -137,16 +137,16 @@ namespace helpers {
Nd4jLong inSize, Nd4jLong inSize,
double scale, double scale,
BilinearInterpolationData *interpolationData) { BilinearInterpolationData *interpolationData) {
interpolationData[outSize].bottomIndex = 0; interpolationData[outSize]._bottomIndex = 0;
interpolationData[outSize].topIndex = 0; interpolationData[outSize]._topIndex = 0;
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto k = start; k < stop; k++) { for (auto k = start; k < stop; k++) {
auto i = (outSize - k - 1); auto i = (outSize - k - 1);
double in = i * scale; double in = i * scale;
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in); interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
} }
}; };
samediff::Threads::parallel_for(func, 0, outSize); samediff::Threads::parallel_for(func, 0, outSize);
@ -159,8 +159,8 @@ namespace helpers {
static void static void
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels, Nd4jLong outWidth, Nd4jLong channels,
std::vector<BilinearInterpolationData> const &xs, std::vector<BilinearInterpolationData> const& xs,
std::vector<BilinearInterpolationData> const &ys, std::vector<BilinearInterpolationData> const& ys,
NDArray *output); NDArray *output);
template<typename T> template<typename T>
@ -175,10 +175,10 @@ namespace helpers {
Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong inBatchNumValues = inHeight * inRowSize;
Nd4jLong outRowSize = outWidth * channels; Nd4jLong outRowSize = outWidth * channels;
T const *pInput = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
BilinearInterpolationData const *xs_ = xs.data(); BilinearInterpolationData const* xsPtr = xs.data();
T* pOutput = output->dataBuffer()->primaryAsT<T>(); T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
auto computeBilinear = [](double topLeft, double topRight, auto computeBilinear = [](double topLeft, double topRight,
double bottomLeft, double bottomRight, double bottomLeft, double bottomRight,
double xVal, double yVal) { double xVal, double yVal) {
@ -188,28 +188,27 @@ namespace helpers {
}; };
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto b = start; b < stop; ++b) { for (auto batch = start; batch < stop; ++batch) {
auto pInput = pInputBuf + batch * inBatchNumValues;
for (auto y = 0; y < outHeight; ++y) { for (auto y = 0; y < outHeight; ++y) {
const T *ys_input_lower_ptr = pInput + ys[y].bottomIndex * inRowSize; auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize;
const T *ys_input_upper_ptr = pInput + ys[y].topIndex * inRowSize; const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize;
double yVal = ys[y].interpolarValue; const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize;
double yVal = ys[y]._interpolarValue;
for (auto x = 0; x < outWidth; ++x) { for (auto x = 0; x < outWidth; ++x) {
auto xsBottom = xs_[x].bottomIndex; auto xsBottom = xsPtr[x]._bottomIndex;
auto xsTop = xs_[x].topIndex; auto xsTop = xsPtr[x]._topIndex;
auto xVal = xs_[x].interpolarValue; auto xVal = xsPtr[x]._interpolarValue;
for (auto c = 0; c < channels; ++c) { for (auto c = 0; c < channels; ++c) {
double topLeft(ys_input_lower_ptr[xsBottom + c]); double topLeft(ysInputLowerPtr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]); double topRight(ysInputLowerPtr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]); double bottomLeft(ysInputUpperPtr[xsBottom + c]);
double bottomRight(ys_input_upper_ptr[xsTop + c]); double bottomRight(ysInputUpperPtr[xsTop + c]);
pOutput[x * channels + c] = pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
xVal, yVal); xVal, yVal);
} }
} }
pOutput += outRowSize;
} }
pInput += inBatchNumValues;
} }
}; };
samediff::Threads::parallel_tad(func, 0, batchSize); samediff::Threads::parallel_tad(func, 0, batchSize);
@ -257,8 +256,8 @@ namespace helpers {
// Scale x interpolation weights to avoid a multiplication during iteration. // Scale x interpolation weights to avoid a multiplication during iteration.
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) { for (auto i = start; i < stop; i += increment) {
xs[i].bottomIndex *= channels; xs[i]._bottomIndex *= channels;
xs[i].topIndex *= channels; xs[i]._topIndex *= channels;
} }
}; };
samediff::Threads::parallel_for(func, 0, xsSize); samediff::Threads::parallel_for(func, 0, xsSize);

View File

@ -33,23 +33,27 @@ namespace helpers {
double scoreThreshold, NDArray* output) { double scoreThreshold, NDArray* output) {
std::vector<int> indices(scales->lengthOf()); std::vector<int> indices(scales->lengthOf());
std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin(), indices.end(), 0);
auto actualIndicesCount = indices.size();
for (auto e = 0; e < scales->lengthOf(); e++) { for (auto e = 0; e < scales->lengthOf(); e++) {
if (scales->e<double>(e) < scoreThreshold) indices[e] = -1; if (scales->e<float>(e) < (float)scoreThreshold) {
indices[e] = -1;
actualIndicesCount--;
} }
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); }
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return i >= 0 && j >=0?scales->e<T>(i) > scales->e<T>(j):(i > j);});
// std::vector<int> selected(output->lengthOf()); // std::vector<int> selected(output->lengthOf());
std::vector<int> selectedIndices(output->lengthOf(), 0); std::vector<int> selectedIndices(output->lengthOf(), 0);
auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool { auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool {
if (previousIndex < 0 || nextIndex < 0) return true; if (previousIndex < 0 || nextIndex < 0) return true;
T minYPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2)); T minYPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
T minXPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3)); T minXPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
T maxYPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2)); T maxYPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
T maxXPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3)); T maxXPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
T minYNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2)); T minYNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
T minXNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3)); T minXNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
T maxYNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2)); T maxYNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
T maxXNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3)); T maxXNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
@ -67,7 +71,7 @@ namespace helpers {
}; };
// int numSelected = 0; // int numSelected = 0;
int numBoxes = boxes->sizeAt(0); int numBoxes = actualIndicesCount; //boxes->sizeAt(0);
int numSelected = 0; int numSelected = 0;
for (int i = 0; i < numBoxes; ++i) { for (int i = 0; i < numBoxes; ++i) {

View File

@ -77,22 +77,19 @@ namespace helpers {
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
if (blockIdx.x < batchSize) { // blockIdx.x as batch index for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
auto pX = input + blockIdx.x * inBatchNumValues; auto pX = input + batch * inBatchNumValues;
auto channelStart = blockIdx.z * blockDim.z + threadIdx.z;
auto step = blockDim.z * gridDim.z;
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
double yVal = ys_[y].interpolarValue; double yVal = ys_[y].interpolarValue;
auto pZ = outputYptr + y * outRowSize; auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) { for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
auto xsBottom = xs_[x].bottomIndex; auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex; auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue; auto xVal = xs_[x].interpolarValue;
// process interpolation for all channels // process interpolation for all channels
for (int c = channelStart; c < channels; c += step) { for (int c = threadIdx.z; c < channels; c += blockDim.z) {
double topLeft(ys_input_lower_ptr[xsBottom + c]); double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]); double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]); double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
@ -120,9 +117,15 @@ namespace helpers {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer()); T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
dim3 batchSizeBlock(batchSize, 1, 1);
resizeImageKernel<T><<<batchSize, outHeight, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize, dim3 pictureBlock(outHeight, outWidth, channels);
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_); outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
auto err = cudaStreamSynchronize(*stream);
if (err != 0) {
throw cuda_exception::build("helpers::resizeImage_: Cannot synchronize kernel execution", err);
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -176,7 +179,6 @@ namespace helpers {
NDArray::prepareSpecialUse({output}, {images}); NDArray::prepareSpecialUse({output}, {images});
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
NDArray::registerSpecialUse({output}, {images}); NDArray::registerSpecialUse({output}, {images});
err = cudaFree(xs_); err = cudaFree(xs_);
if (err != 0) { if (err != 0) {
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err); throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);

View File

@ -1530,7 +1530,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {0.7788f, 0.8012f, 0.7244f, NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f, 0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f, 0.5056f, 0.8925f, 0.5461f,
0.9234f, 0.0856f, 0.7938f, 0.9234f, 0.0856f, 0.7938f,
@ -1581,40 +1582,89 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
0.4739f, 0.7014f, 0.4473f, 0.4739f, 0.7014f, 0.4473f,
0.5171f, 0.1744f, 0.3487f}); 0.5171f, 0.1744f, 0.3487f});
NDArray expected = NDArrayFactory::create<float>('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 3}, {
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., 0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, 0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f,
5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, 0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f,
9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, 0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f,
11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f,
7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 0.46649635f,
10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., 0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f,
13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, 0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f,
8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, 0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f,
11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f,
15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f,
16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, 0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, 0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f,
16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, 0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f,
18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f,
14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., 0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, 0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f,
20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f,
21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, 0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f,
15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, 0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f,
19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., 0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f,
21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f,
14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, 0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., 0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f,
13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, 0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f,
20.2,21.2, 22.2, 23.2, 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); 0.47803456f, 0.3974851f, 0.7738359f, 0.4686691f, 0.27816284f, 0.8476581f,
0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f,
0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f,
0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f,
0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f,
0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f,
0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f,
0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f,
0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f,
0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f,
0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f,
0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f,
0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f,
0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f,
0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f,
0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f,
0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f,
0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f,
0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f,
0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f,
0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f,
0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f,
0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f,
0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f,
0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f,
0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f,
0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f,
0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f,
0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f,
0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f,
0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f,
0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f,
0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f,
0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f,
0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f,
0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f,
0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f,
0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f,
0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f,
0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f,
0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f,
0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f,
0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f,
0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f,
0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f,
0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f,
0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f,
0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f,
0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f
});
//input.linspace(1); //input.linspace(1);
nd4j::ops::resize_bilinear op; nd4j::ops::resize_bilinear op;
@ -1624,12 +1674,12 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
NDArray* result = results->at(0); NDArray* result = results->at(0);
result->printIndexedBuffer("Resized to 9x9"); // result->printBuffer("Resized to 9x9");
//expected.printIndexedBuffer("Expect for 10x10"); // expected.printBuffer("Expect for 9x9");
result->printShapeInfo("Output shape"); // result->printShapeInfo("Output shape");
// expected.printShapeInfo("Expect shape"); // expected.printShapeInfo("Expect shape");
// ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
delete results; delete results;
} }
@ -2015,6 +2065,53 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
delete results; delete results;
} }
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printIndexedBuffer("Resized to 4x5");
// expected.printIndexedBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
NDArray input = NDArrayFactory::create<double>('c', {2, 3, 4}); NDArray input = NDArrayFactory::create<double>('c', {2, 3, 4});
@ -2166,6 +2263,73 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
delete results; delete results;
} }
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
NDArray boxes = NDArrayFactory::create<float>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput3");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5);
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput4");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1, 2});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput4");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
@ -2692,6 +2856,46 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
delete results; delete results;
} }
////////////////////////////////////////////////////////////////////
//TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
//
// NDArray x = NDArrayFactory::create<double>('c', {100});
// NDArray exp = NDArrayFactory::create<double>('c', {100}, {
// 0.f, 0.f, 0.f , 0.f , 0.06666667f, 0.06666667f ,
// 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.13333334 ,
// 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.20000002 ,
// 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002 ,
// 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668 ,
// 0.26666668, 0.33333334, 0.33333334, 0.33333334, 0.33333334, 0.33333334 ,
// 0.33333334, 0.40000004, 0.40000004, 0.40000004, 0.40000004, 0.40000004 ,
// 0.40000004, 0.40000004, 0.4666667 , 0.4666667 , 0.4666667 , 0.4666667 ,
// 0.4666667 , 0.4666667 , 0.4666667 , 0.53333336, 0.53333336, 0.53333336 ,
// 0.53333336, 0.53333336, 0.53333336, 0.6 , 0.6 , 0.6 ,
// 0.6 , 0.6 , 0.6 , 0.6 , 0.6666667 , 0.6666667 ,
// 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.73333335 ,
// 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.8000001 ,
// 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 ,
// 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673 ,
// 0.86666673, 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 ,
// 0.9333334 , 1., 1., 1.,
// });
// NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
// NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
// x.linspace(0., 0.01);
// nd4j::ops::fake_quant_with_min_max_vars op;
// auto results = op.execute({&x, &min, &max}, {}, {});
//
// ASSERT_EQ(ND4J_STATUS_OK, results->status());
//
// auto result = results->at(0);
// result->printBuffer("Quantized7");
// exp.printBuffer("Expected 7");
// ASSERT_TRUE(exp.isSameShapeStrict(result));
// ASSERT_TRUE(exp.equalsTo(result));
//
// delete results;
//}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test1) { TEST_F(DeclarableOpsTests10, batchnorm_test1) {

View File

@ -254,6 +254,34 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
auto x = NDArrayFactory::create<float>('c', {1, 4});
x.linspace(1.);
nd4j::ops::bitcast op;
try {
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_NE(Status::OK(), result->status());
delete result;
} catch (std::exception& e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
}
}
TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
auto x = NDArrayFactory::create<float>('c', {1, 4});
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {1234567890LL, 2468013579LL});
x.linspace(1.);
nd4j::ops::bitcast op;
try {
auto result = op.execute({&x}, {&e}, {}, {nd4j::DataType::INT64}, {});
ASSERT_NE(Status::OK(), result);
} catch(std::exception& e) {
nd4j_printf("Error `%s' should be here. It's OK.\n",e.what());
}
}
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64}); auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2}); auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});

View File

@ -426,6 +426,10 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
} }
TEST_F(NlpTests, test_cbow_hs_batch_1) { TEST_F(NlpTests, test_cbow_hs_batch_1) {
#ifdef __CUDABLAS__
return ;
#endif
auto target = NDArrayFactory::create<int>(0); auto target = NDArrayFactory::create<int>(0);
auto ngStarter = NDArrayFactory::empty<int>(); auto ngStarter = NDArrayFactory::empty<int>();
auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102}); auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102});

View File

@ -62,8 +62,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
if(inputDataTypes.get(0).isFPType())
return Collections.singletonList(inputDataTypes.get(0)); return Collections.singletonList(inputDataTypes.get(0));
return Collections.singletonList(Nd4j.defaultFloatingPointType());
} }
} }

View File

@ -8005,9 +8005,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8024,6 +8027,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
@ -8043,6 +8049,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@ -8354,6 +8363,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { // INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
@ -8778,7 +8791,7 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
/** /**
@ -9110,6 +9123,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////

View File

@ -6,6 +6,9 @@ import java.nio.*;
import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*; import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
static { Loader.load(); } static { Loader.load(); }
@ -8005,9 +8008,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8024,6 +8030,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
@ -8043,6 +8052,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@ -8354,6 +8366,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) { // INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
@ -8778,7 +8794,7 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
/** /**
@ -9110,6 +9126,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////