Shugeo release fix2 (#70)

* Corrected input checking and tests for bitcast op.

* Fixed an issue with non_max_suppression form generation and processing with score threshold given.

* Fixed bilinear resize kernel and tests.

* push for Serhii

Signed-off-by: raver119 <raver119@gmail.com>

* Added test for nearest_neighbor resize with int input.

* Added data type check for input/output match.

* Eliminate error in macros.

* Improved output message for type checking.

* Fixed input/output types for op.

* Eliminated waste logging.

* Refactored resize_bilinear helper for multithreading for cpu platform.

* Cosmetic changes only.

* Fixed error for string substitution.

* Skip test for cbow_batch with cuda.

* fix for resizeNearestNeighbor output dtype

Signed-off-by: raver119 <raver119@gmail.com>

* Refactored non_max_suppression helper.

* Refactored shape generation and input handling.

* Added additional test.
master
shugeo 2019-11-22 21:42:44 +02:00 committed by raver119
parent 289d9dc141
commit 4187190609
13 changed files with 432 additions and 127 deletions

View File

@ -30,6 +30,17 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// when empty - nothing to do
DataType newType = DataTypeUtils::fromInt(INT_ARG(0));
DataType oldType = input->dataType();
// correct output shape to conform with output data type
auto inputSize = DataTypeUtils::sizeOf(oldType);
auto outputSize = DataTypeUtils::sizeOf(newType);
auto lastSize = outputSize / inputSize;
if (inputSize < outputSize) {
REQUIRE_TRUE(input->sizeAt(-1) == lastSize, 0,
"BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize,
outputSize, lastSize, input->sizeAt(-1));
}
if(input->isEmpty()){
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
return Status::OK();
@ -70,7 +81,7 @@ namespace nd4j {
auto outputShape = ConstantShapeHelper::getInstance()->createShapeInfo(newType, shape::order(inShape), shapeOf);
return SHAPELIST(outputShape);
}
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %ull > %ull. So last dimension should be %ull, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
REQUIRE_TRUE(shape::sizeAt(inShape, -1) == outputSize / inputSize, 0, "BITCAST: %llu > %llu. So last dimension should be %i, but %i given.", inputSize, outputSize, outputSize / inputSize, shape::sizeAt(inShape, -1));
std::vector<Nd4jLong> shapeOf(inputRank - 1);
for (auto i = 0; i < shapeOf.size(); ++i) {

View File

@ -37,6 +37,22 @@ namespace nd4j {
else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.width() > 3) {
overlayThreshold = INPUT_VARIABLE(3)->e<double>(0);
}
else if (block.getTArguments()->size() > 0) {
overlayThreshold = T_ARG(0);
}
if (block.width() > 4) {
scoreThreshold = INPUT_VARIABLE(4)->e<double>(0);
}
else if (block.getTArguments()->size() > 1) {
scoreThreshold = T_ARG(1);
}
if (boxes->isEmpty() || scales->isEmpty())
return Status::OK();
@ -44,15 +60,6 @@ namespace nd4j {
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());
if (scales->lengthOf() < maxOutputSize)
maxOutputSize = scales->lengthOf();
double overlayThreshold = 0.5;
double scoreThreshold = - DataTypeUtils::infOrMax<float>();
if (block.getTArguments()->size() > 0)
overlayThreshold = T_ARG(0);
if (block.getTArguments()->size() > 1)
scoreThreshold = T_ARG(1);
helpers::nonMaxSuppression(block.launchContext(), boxes, scales, maxOutputSize, overlayThreshold, scoreThreshold, output);
return Status::OK();
}
@ -70,10 +77,19 @@ namespace nd4j {
else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
Nd4jLong boxSize = shape::sizeAt(in, 0);
if (boxSize < maxOutputSize)
maxOutputSize = boxSize;
auto actualIndicesCount = shape::sizeAt(in, 0);
if (block.getTArguments()->size() > 1 || block.width() > 4) {
auto scoreThreshold = block.getTArguments()->size() > 1?T_ARG(1):INPUT_VARIABLE(4)->e<double>(0);
auto scales = INPUT_VARIABLE(1);
scales->syncToHost();
for (auto e = 0; e < scales->lengthOf(); e++) {
if (scales->e<float>(e) < (float)scoreThreshold) {
actualIndicesCount--;
}
}
}
if (actualIndicesCount < maxOutputSize)
maxOutputSize = actualIndicesCount;
outputShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(maxOutputSize, DataType::INT32);

View File

@ -54,7 +54,9 @@ namespace nd4j {
auto inRank = image->rankOf();
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
@ -105,8 +107,8 @@ namespace nd4j {
}
DECLARE_TYPES(resize_nearest_neighbor) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
}
}

View File

@ -26,7 +26,6 @@
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(broadcast_to, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);

View File

@ -41,11 +41,11 @@ namespace ops {
namespace helpers {
struct BilinearInterpolationData {
Nd4jLong bottomIndex; // Lower source index used in the interpolation
Nd4jLong topIndex; // Upper source index used in the interpolation
// 1-D linear iterpolation scale (see:
Nd4jLong _bottomIndex; // Lower source index used in the interpolation
Nd4jLong _topIndex; // Upper source index used in the interpolation
// 1D linear iterpolation scale (see:
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
double interpolarValue;
double _interpolarValue;
};
// calculateResizeScale determines the float scaling factor.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
@ -137,16 +137,16 @@ namespace helpers {
Nd4jLong inSize,
double scale,
BilinearInterpolationData *interpolationData) {
interpolationData[outSize].bottomIndex = 0;
interpolationData[outSize].topIndex = 0;
interpolationData[outSize]._bottomIndex = 0;
interpolationData[outSize]._topIndex = 0;
auto func = PRAGMA_THREADS_FOR {
for (auto k = start; k < stop; k++) {
auto i = (outSize - k - 1);
double in = i * scale;
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
}
};
samediff::Threads::parallel_for(func, 0, outSize);
@ -175,10 +175,10 @@ namespace helpers {
Nd4jLong inBatchNumValues = inHeight * inRowSize;
Nd4jLong outRowSize = outWidth * channels;
T const *pInput = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
BilinearInterpolationData const *xs_ = xs.data();
T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
BilinearInterpolationData const* xsPtr = xs.data();
T* pOutput = output->dataBuffer()->primaryAsT<T>();
T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
auto computeBilinear = [](double topLeft, double topRight,
double bottomLeft, double bottomRight,
double xVal, double yVal) {
@ -188,28 +188,27 @@ namespace helpers {
};
auto func = PRAGMA_THREADS_FOR {
for (auto b = start; b < stop; ++b) {
for (auto batch = start; batch < stop; ++batch) {
auto pInput = pInputBuf + batch * inBatchNumValues;
for (auto y = 0; y < outHeight; ++y) {
const T *ys_input_lower_ptr = pInput + ys[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = pInput + ys[y].topIndex * inRowSize;
double yVal = ys[y].interpolarValue;
auto pOutput = pOutputBuf + (batch * outHeight + y) * outRowSize;
const T* ysInputLowerPtr = pInput + ys[y]._bottomIndex * inRowSize;
const T* ysInputUpperPtr = pInput + ys[y]._topIndex * inRowSize;
double yVal = ys[y]._interpolarValue;
for (auto x = 0; x < outWidth; ++x) {
auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue;
auto xsBottom = xsPtr[x]._bottomIndex;
auto xsTop = xsPtr[x]._topIndex;
auto xVal = xsPtr[x]._interpolarValue;
for (auto c = 0; c < channels; ++c) {
double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
double bottomRight(ys_input_upper_ptr[xsTop + c]);
pOutput[x * channels + c] =
computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
double topLeft(ysInputLowerPtr[xsBottom + c]);
double topRight(ysInputLowerPtr[xsTop + c]);
double bottomLeft(ysInputUpperPtr[xsBottom + c]);
double bottomRight(ysInputUpperPtr[xsTop + c]);
pOutput[x * channels + c] = computeBilinear(topLeft, topRight, bottomLeft, bottomRight,
xVal, yVal);
}
}
pOutput += outRowSize;
}
pInput += inBatchNumValues;
}
};
samediff::Threads::parallel_tad(func, 0, batchSize);
@ -257,8 +256,8 @@ namespace helpers {
// Scale x interpolation weights to avoid a multiplication during iteration.
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment) {
xs[i].bottomIndex *= channels;
xs[i].topIndex *= channels;
xs[i]._bottomIndex *= channels;
xs[i]._topIndex *= channels;
}
};
samediff::Threads::parallel_for(func, 0, xsSize);

View File

@ -33,23 +33,27 @@ namespace helpers {
double scoreThreshold, NDArray* output) {
std::vector<int> indices(scales->lengthOf());
std::iota(indices.begin(), indices.end(), 0);
auto actualIndicesCount = indices.size();
for (auto e = 0; e < scales->lengthOf(); e++) {
if (scales->e<double>(e) < scoreThreshold) indices[e] = -1;
if (scales->e<float>(e) < (float)scoreThreshold) {
indices[e] = -1;
actualIndicesCount--;
}
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
}
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return i >= 0 && j >=0?scales->e<T>(i) > scales->e<T>(j):(i > j);});
// std::vector<int> selected(output->lengthOf());
std::vector<int> selectedIndices(output->lengthOf(), 0);
auto needToSuppressWithThreshold = [] (NDArray& boxes, int previousIndex, int nextIndex, T threshold) -> bool {
if (previousIndex < 0 || nextIndex < 0) return true;
T minYPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2));
T minXPrev = nd4j::math::nd4j_min(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3));
T maxYPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 0), boxes.e<T>(previousIndex, 2));
T maxXPrev = nd4j::math::nd4j_max(boxes.e<T>(previousIndex, 1), boxes.e<T>(previousIndex, 3));
T minYNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2));
T minXNext = nd4j::math::nd4j_min(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3));
T maxYNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 0), boxes.e<T>(nextIndex, 2));
T maxXNext = nd4j::math::nd4j_max(boxes.e<T>(nextIndex, 1), boxes.e<T>(nextIndex, 3));
T minYPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
T minXPrev = nd4j::math::nd4j_min(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
T maxYPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 0), boxes.t<T>(previousIndex, 2));
T maxXPrev = nd4j::math::nd4j_max(boxes.t<T>(previousIndex, 1), boxes.t<T>(previousIndex, 3));
T minYNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
T minXNext = nd4j::math::nd4j_min(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
T maxYNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 0), boxes.t<T>(nextIndex, 2));
T maxXNext = nd4j::math::nd4j_max(boxes.t<T>(nextIndex, 1), boxes.t<T>(nextIndex, 3));
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
@ -67,7 +71,7 @@ namespace helpers {
};
// int numSelected = 0;
int numBoxes = boxes->sizeAt(0);
int numBoxes = actualIndicesCount; //boxes->sizeAt(0);
int numSelected = 0;
for (int i = 0; i < numBoxes; ++i) {

View File

@ -77,22 +77,19 @@ namespace helpers {
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
if (blockIdx.x < batchSize) { // blockIdx.x as batch index
auto pX = input + blockIdx.x * inBatchNumValues;
auto channelStart = blockIdx.z * blockDim.z + threadIdx.z;
auto step = blockDim.z * gridDim.z;
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
auto pX = input + batch * inBatchNumValues;
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
double yVal = ys_[y].interpolarValue;
auto pZ = outputYptr + y * outRowSize;
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue;
// process interpolation for all channels
for (int c = channelStart; c < channels; c += step) {
for (int c = threadIdx.z; c < channels; c += blockDim.z) {
double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
@ -120,9 +117,15 @@ namespace helpers {
auto stream = context->getCudaStream();
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
resizeImageKernel<T><<<batchSize, outHeight, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
dim3 batchSizeBlock(batchSize, 1, 1);
dim3 pictureBlock(outHeight, outWidth, channels);
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
auto err = cudaStreamSynchronize(*stream);
if (err != 0) {
throw cuda_exception::build("helpers::resizeImage_: Cannot synchronize kernel execution", err);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -176,7 +179,6 @@ namespace helpers {
NDArray::prepareSpecialUse({output}, {images});
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
NDArray::registerSpecialUse({output}, {images});
err = cudaFree(xs_);
if (err != 0) {
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);

View File

@ -1530,7 +1530,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {0.7788f, 0.8012f, 0.7244f,
NDArray input = NDArrayFactory::create<float>('c', {2, 5,5,3}, {
0.7788f, 0.8012f, 0.7244f,
0.2309f, 0.7271f, 0.1804f,
0.5056f, 0.8925f, 0.5461f,
0.9234f, 0.0856f, 0.7938f,
@ -1581,40 +1582,89 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
0.4739f, 0.7014f, 0.4473f,
0.5171f, 0.1744f, 0.3487f});
NDArray expected = NDArrayFactory::create<float>('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2,
9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4,
11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8,
7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4,
10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16.,
13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8,
8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6,
11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2,
16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6,
18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24.,
21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,
15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,
19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24.,
21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16.,
14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,
17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.,
13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,
16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22.,
20.2,21.2, 22.2, 23.2,
21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.});
NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 3}, {
0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f,
0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f,
0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f,
0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f,
0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f,
0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f,
0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 0.46649635f,
0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f,
0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f,
0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f,
0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f,
0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f,
0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f,
0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f,
0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f,
0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f,
0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f,
0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f,
0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f,
0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f,
0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f,
0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f,
0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f,
0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f,
0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f,
0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f,
0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f,
0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f,
0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f,
0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f,
0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f,
0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f,
0.47803456f, 0.3974851f, 0.7738359f, 0.4686691f, 0.27816284f, 0.8476581f,
0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f,
0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f,
0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f,
0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f,
0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f,
0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f,
0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f,
0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f,
0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f,
0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f,
0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f,
0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f,
0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f,
0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f,
0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f,
0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f,
0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f,
0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f,
0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f,
0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f,
0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f,
0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f,
0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f,
0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f,
0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f,
0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 0.5285778f,
0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f,
0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f,
0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f,
0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f,
0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f,
0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f,
0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f,
0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f,
0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f,
0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f,
0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f,
0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f,
0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f,
0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f,
0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f,
0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f,
0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f,
0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f,
0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f,
0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f,
0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f,
0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f
});
//input.linspace(1);
nd4j::ops::resize_bilinear op;
@ -1624,12 +1674,12 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
NDArray* result = results->at(0);
result->printIndexedBuffer("Resized to 9x9");
//expected.printIndexedBuffer("Expect for 10x10");
result->printShapeInfo("Output shape");
// result->printBuffer("Resized to 9x9");
// expected.printBuffer("Expect for 9x9");
// result->printShapeInfo("Output shape");
// expected.printShapeInfo("Expect shape");
// ASSERT_TRUE(expected.isSameShape(result));
// ASSERT_TRUE(expected.equalsTo(result));
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
@ -2015,6 +2065,53 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printIndexedBuffer("Resized to 4x5");
// expected.printIndexedBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) {
NDArray input = NDArrayFactory::create<double>('c', {2, 3, 4});
@ -2166,6 +2263,73 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) {
NDArray boxes = NDArrayFactory::create<float>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales}, {0.5, 0.5}, {2});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput3");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) {
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {1}, {1});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(0.5);
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput4");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) {
NDArray boxes = NDArrayFactory::create<float16>('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f,
0.7412f, 0.7607f, 0.1543f, 0.5479f,
0.8223f, 0.2246f, 0.0049f, 0.6465f});
NDArray scales = NDArrayFactory::create<float16>('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {2}, {1, 2});
NDArray maxSize = NDArrayFactory::create(2);
NDArray threshold = NDArrayFactory::create(0.5f);
NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax<float>());
nd4j::ops::non_max_suppression op;
auto results = op.execute({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {});
ASSERT_EQ(Status::OK(), results->status());
NDArray* result = results->at(0);
// result->printBuffer("NonMaxSuppression OUtput4");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) {
@ -2692,6 +2856,46 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
delete results;
}
////////////////////////////////////////////////////////////////////
//TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) {
//
// NDArray x = NDArrayFactory::create<double>('c', {100});
// NDArray exp = NDArrayFactory::create<double>('c', {100}, {
// 0.f, 0.f, 0.f , 0.f , 0.06666667f, 0.06666667f ,
// 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.06666667, 0.13333334 ,
// 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.13333334, 0.20000002 ,
// 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002, 0.20000002 ,
// 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668, 0.26666668 ,
// 0.26666668, 0.33333334, 0.33333334, 0.33333334, 0.33333334, 0.33333334 ,
// 0.33333334, 0.40000004, 0.40000004, 0.40000004, 0.40000004, 0.40000004 ,
// 0.40000004, 0.40000004, 0.4666667 , 0.4666667 , 0.4666667 , 0.4666667 ,
// 0.4666667 , 0.4666667 , 0.4666667 , 0.53333336, 0.53333336, 0.53333336 ,
// 0.53333336, 0.53333336, 0.53333336, 0.6 , 0.6 , 0.6 ,
// 0.6 , 0.6 , 0.6 , 0.6 , 0.6666667 , 0.6666667 ,
// 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.6666667 , 0.73333335 ,
// 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.73333335, 0.8000001 ,
// 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 , 0.8000001 ,
// 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673, 0.86666673 ,
// 0.86666673, 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 , 0.9333334 ,
// 0.9333334 , 1., 1., 1.,
// });
// NDArray min = NDArrayFactory::create<float>('c', {1},{0.0f});
// NDArray max = NDArrayFactory::create<float>('c', {1}, {1.f});
// x.linspace(0., 0.01);
// nd4j::ops::fake_quant_with_min_max_vars op;
// auto results = op.execute({&x, &min, &max}, {}, {});
//
// ASSERT_EQ(ND4J_STATUS_OK, results->status());
//
// auto result = results->at(0);
// result->printBuffer("Quantized7");
// exp.printBuffer("Expected 7");
// ASSERT_TRUE(exp.isSameShapeStrict(result));
// ASSERT_TRUE(exp.equalsTo(result));
//
// delete results;
//}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, batchnorm_test1) {

View File

@ -254,6 +254,34 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_3) {
auto x = NDArrayFactory::create<float>('c', {1, 4});
x.linspace(1.);
nd4j::ops::bitcast op;
try {
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_NE(Status::OK(), result->status());
delete result;
} catch (std::exception& e) {
nd4j_printf("Error should be here `%s'. It's OK.\n", e.what());
}
}
TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
auto x = NDArrayFactory::create<float>('c', {1, 4});
auto e = NDArrayFactory::create<Nd4jLong>('c', {1, 2}, {1234567890LL, 2468013579LL});
x.linspace(1.);
nd4j::ops::bitcast op;
try {
auto result = op.execute({&x}, {&e}, {}, {nd4j::DataType::INT64}, {});
ASSERT_NE(Status::OK(), result);
} catch(std::exception& e) {
nd4j_printf("Error `%s' should be here. It's OK.\n",e.what());
}
}
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});

View File

@ -426,6 +426,10 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
}
TEST_F(NlpTests, test_cbow_hs_batch_1) {
#ifdef __CUDABLAS__
return ;
#endif
auto target = NDArrayFactory::create<int>(0);
auto ngStarter = NDArrayFactory::empty<int>();
auto context = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 100, 101, 102});

View File

@ -62,8 +62,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
if(inputDataTypes.get(0).isFPType())
return Collections.singletonList(inputDataTypes.get(0));
return Collections.singletonList(Nd4j.defaultFloatingPointType());
}
}

View File

@ -8005,9 +8005,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8024,6 +8027,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
@ -8043,6 +8049,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@ -8354,6 +8363,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
@ -8778,7 +8791,7 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
/**
@ -9110,6 +9123,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////

View File

@ -6,6 +6,9 @@ import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
static { Loader.load(); }
@ -8005,9 +8008,12 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8024,6 +8030,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] coords);
@ -8043,6 +8052,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@ -8354,6 +8366,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
@ -8778,7 +8794,7 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
/**
@ -9110,6 +9126,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////