Shugeo resize fix5 (#102)
* Refactored resize images ops to use TF-like bool args as input. * Refactored helpers for cpu implementation of resize_bilinear and resize_nearest_neighbor ops. * Refactored cuda implementation for image.resize_bilinear and image.resize_nearest_neighbor ops helpers. * Refactored nearest_neighbor resize op. * Added a pair of tests for special case of resize_bilinear algorithm. * Fixed issue with resize_bilinear op. * Refactored cpu implementation for helpers with resize_nearest_neighbor op. * Final fixed for resize ops to conform TF v.1.5 * Refactored cuda helpers for resize_neares_neighbor op. * Fixed resize_bilinear to accept proper data. * Fixed issue with non-float input for resize_bilinear op. * Refactored cuda helper for resize_bilinear to proper process non-float inputs. * Added tests for resize_bilinear to int inputs. * Fixed ResizeBilinear wrapper * Tests fixed * Fixed float and bool constant to avoid overflow for some kind of compilers. * Corrected float constants with float data type. * Added f suffix for float constants. * Corrected float constant to avoid overflow with initializing lists. * Corrected float initializing list with float input. * Corrected bool constant with initalizing list. * Corrected float and bool values with initializing lists. * Fixed wrong constant. * Fixed issue with 1x1 input picture for resize. * ResizeBilinear default values on import fix Signed-off-by: raver119 <raver119@gmail.com>master
parent
6a3c046ffd
commit
e09a785232
|
@ -1256,6 +1256,9 @@ namespace nd4j {
|
||||||
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
|
||||||
template<typename T>
|
template<typename T>
|
||||||
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
|
||||||
|
template<typename T>
|
||||||
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
|
@ -1268,6 +1271,8 @@ namespace nd4j {
|
||||||
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;
|
||||||
template<typename T>
|
template<typename T>
|
||||||
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const;
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const;
|
||||||
|
template<typename T>
|
||||||
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1711,7 +1716,7 @@ namespace nd4j {
|
||||||
if (isEmpty())
|
if (isEmpty())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return shape::isMatrix(this->_shapeInfo);
|
return 0 != shape::isMatrix(this->_shapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1751,7 +1756,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::isScalar() const {
|
bool NDArray::isScalar() const {
|
||||||
return shape::isScalar(this->_shapeInfo);
|
return 0 != shape::isScalar(this->_shapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2082,7 +2087,7 @@ template <typename T>
|
||||||
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
||||||
|
|
||||||
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
|
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
|
||||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
@ -2095,6 +2100,23 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
||||||
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) {
|
||||||
|
|
||||||
|
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3))
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !");
|
||||||
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
if(!isActualOnHostSide())
|
||||||
|
syncToHost();
|
||||||
|
|
||||||
|
Nd4jLong coords[4] = {i, j, k, w};
|
||||||
|
auto offset = shape::getOffset(getShapeInfo(), coords);
|
||||||
|
tickWriteHost();
|
||||||
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T NDArray::t(const Nd4jLong i) const {
|
T NDArray::t(const Nd4jLong i) const {
|
||||||
|
@ -2133,7 +2155,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
|
||||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
||||||
|
|
||||||
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
|
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
|
||||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
@ -2146,6 +2168,23 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
|
||||||
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const {
|
||||||
|
|
||||||
|
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3))
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!");
|
||||||
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
if(!isActualOnHostSide())
|
||||||
|
syncToHost();
|
||||||
|
|
||||||
|
Nd4jLong coords[4] = {i, j, k, w};
|
||||||
|
auto offset = shape::getOffset(getShapeInfo(), coords);
|
||||||
|
tickReadHost();
|
||||||
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::shared_ptr<DataBuffer> NDArray::getDataBuffer() const {
|
std::shared_ptr<DataBuffer> NDArray::getDataBuffer() const {
|
||||||
|
|
|
@ -35,6 +35,8 @@ namespace nd4j {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
auto inRank = image->rankOf();
|
auto inRank = image->rankOf();
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
|
|
||||||
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
|
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
|
||||||
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
|
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
|
||||||
REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
|
REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
|
||||||
|
@ -57,7 +59,7 @@ namespace nd4j {
|
||||||
if (block.numB()> 1)
|
if (block.numB()> 1)
|
||||||
halfPixelAlign = block.getBArguments()->at(1);
|
halfPixelAlign = block.getBArguments()->at(1);
|
||||||
}
|
}
|
||||||
REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners");
|
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
|
||||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||||
|
|
|
@ -32,8 +32,10 @@ namespace nd4j {
|
||||||
NDArray* output = OUTPUT_VARIABLE(0);
|
NDArray* output = OUTPUT_VARIABLE(0);
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
bool center = false; // - default value
|
bool alignCorners = false; // - default value
|
||||||
auto inRank = image->rankOf();
|
auto inRank = image->rankOf();
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
|
|
||||||
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
|
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
|
||||||
"tensor, but input has rank %i",
|
"tensor, but input has rank %i",
|
||||||
image->rankOf());
|
image->rankOf());
|
||||||
|
@ -46,21 +48,25 @@ namespace nd4j {
|
||||||
auto newImageSize = INPUT_VARIABLE(1);
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
||||||
width = newImageSize->e<int>(0);
|
height = newImageSize->e<int>(0);
|
||||||
height = newImageSize->e<int>(1);
|
width = newImageSize->e<int>(1);
|
||||||
if (block.numI() == 1) {
|
|
||||||
center = 0 != INT_ARG(0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
height = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
width = INT_ARG(1);
|
||||||
if (block.numI() == 3)
|
|
||||||
center = 0 != INT_ARG(2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
if (block.numB() > 0)
|
||||||
|
alignCorners = B_ARG(0);
|
||||||
|
bool halfPixelCenter = false;
|
||||||
|
|
||||||
|
if (block.numB() > 1)
|
||||||
|
halfPixelCenter = B_ARG(1);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
|
||||||
|
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(resize_bilinear) {
|
DECLARE_SHAPE_FN(resize_bilinear) {
|
||||||
|
@ -83,7 +89,7 @@ namespace nd4j {
|
||||||
height = newImageSize->e<int>(1);
|
height = newImageSize->e<int>(1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
width = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
height = INT_ARG(1);
|
||||||
}
|
}
|
||||||
|
@ -101,7 +107,12 @@ namespace nd4j {
|
||||||
outputShape[2] = height;
|
outputShape[2] = height;
|
||||||
outputShape[3] = in[3];
|
outputShape[3] = in[3];
|
||||||
}
|
}
|
||||||
|
if (DataTypeUtils::isR(ArrayOptions::dataType(in))) {
|
||||||
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
|
||||||
|
}
|
||||||
|
|
||||||
shapeList->push_back(CONSTANT(outputShape));
|
shapeList->push_back(CONSTANT(outputShape));
|
||||||
return shapeList;
|
return shapeList;
|
||||||
|
|
|
@ -31,35 +31,40 @@ namespace nd4j {
|
||||||
|
|
||||||
auto image = INPUT_VARIABLE(0);
|
auto image = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
auto inRank = image->rankOf();
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
bool center = false; // - default value
|
bool alignCorners = false; // - default value
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
if (block.width() > 1) {
|
if (block.width() > 1) {
|
||||||
auto newImageSize = INPUT_VARIABLE(1);
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
|
||||||
width = newImageSize->e<int>(0);
|
height = newImageSize->e<int>(0);
|
||||||
height = newImageSize->e<int>(1);
|
width = newImageSize->e<int>(1);
|
||||||
if (block.numI() == 1) {
|
|
||||||
center = 0 != INT_ARG(0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
height = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
width = INT_ARG(1);
|
||||||
if (block.numI() == 3)
|
|
||||||
center = 0 != INT_ARG(2);
|
|
||||||
}
|
}
|
||||||
auto inRank = image->rankOf();
|
if (block.numB() > 0)
|
||||||
|
alignCorners = B_ARG(0);
|
||||||
|
bool halfPixelCenter = false;
|
||||||
|
|
||||||
|
if (block.numB() > 1)
|
||||||
|
halfPixelCenter = B_ARG(1);
|
||||||
|
REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
|
||||||
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
||||||
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
||||||
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
|
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
|
||||||
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
|
||||||
|
|
||||||
|
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||||
|
|
||||||
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
|
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
|
||||||
|
|
|
@ -120,6 +120,27 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
||||||
|
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
||||||
|
struct HalfPixelScalerNN {
|
||||||
|
HalfPixelScalerNN(){};
|
||||||
|
inline float operator()(const int x, const float scale) const {
|
||||||
|
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
||||||
|
// sampling code etc assumes pixels are in the old coordinate system.
|
||||||
|
return (static_cast<float>(x) + 0.5f) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Older incorrect scaling method that causes all resizes to have a slight
|
||||||
|
// translation leading to inconsistent results. For example, a flip then a
|
||||||
|
// resize gives different results then a resize then a flip.
|
||||||
|
struct LegacyScaler {
|
||||||
|
LegacyScaler(){};
|
||||||
|
inline float operator()(const int x, const float scale) const {
|
||||||
|
return static_cast<float>(x) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct WeightsAndIndices {
|
struct WeightsAndIndices {
|
||||||
float _weight0;
|
float _weight0;
|
||||||
float _weight1;
|
float _weight1;
|
||||||
|
@ -133,7 +154,8 @@ namespace helpers {
|
||||||
int _advance; // advance value.
|
int _advance; // advance value.
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void computeInterpolationWeights(Nd4jLong outSize,
|
template <class Scaler>
|
||||||
|
inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize,
|
||||||
Nd4jLong inSize,
|
Nd4jLong inSize,
|
||||||
double scale,
|
double scale,
|
||||||
BilinearInterpolationData *interpolationData) {
|
BilinearInterpolationData *interpolationData) {
|
||||||
|
@ -143,10 +165,12 @@ namespace helpers {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto k = start; k < stop; k++) {
|
for (auto k = start; k < stop; k++) {
|
||||||
auto i = (outSize - k - 1);
|
auto i = (outSize - k - 1);
|
||||||
double in = i * scale;
|
double const in = scaler(i, scale);
|
||||||
interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
|
double const in_f = nd4j::math::nd4j_floor<double, double>(in);
|
||||||
interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
|
double const in_c = nd4j::math::nd4j_ceil<double, double>(in);
|
||||||
interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
|
interpolationData[i]._bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
|
||||||
|
interpolationData[i]._topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
|
||||||
|
interpolationData[i]._interpolarValue = in - in_f;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, outSize);
|
samediff::Threads::parallel_for(func, 0, outSize);
|
||||||
|
@ -156,29 +180,29 @@ namespace helpers {
|
||||||
* Computes the bilinear interpolation from the appropriate 4 float points
|
* Computes the bilinear interpolation from the appropriate 4 float points
|
||||||
* and the linear interpolation weights.
|
* and the linear interpolation weights.
|
||||||
*/
|
*/
|
||||||
static void
|
// static void
|
||||||
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
// Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const& xs,
|
// std::vector<BilinearInterpolationData> const& xs,
|
||||||
std::vector<BilinearInterpolationData> const& ys,
|
// std::vector<BilinearInterpolationData> const& ys,
|
||||||
NDArray *output);
|
// NDArray *output);
|
||||||
|
|
||||||
template<typename T>
|
template<typename T, typename Z>
|
||||||
static void
|
static void
|
||||||
resizeImage_(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const &xs,
|
std::vector<BilinearInterpolationData> const &xs,
|
||||||
std::vector<BilinearInterpolationData> const &ys,
|
std::vector<BilinearInterpolationData> const &ys,
|
||||||
NDArray *output) {
|
Z* pOutputBuf) {
|
||||||
|
|
||||||
Nd4jLong inRowSize = inWidth * channels;
|
Nd4jLong inRowSize = inWidth * channels;
|
||||||
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
||||||
Nd4jLong outRowSize = outWidth * channels;
|
Nd4jLong outRowSize = outWidth * channels;
|
||||||
|
|
||||||
T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
// T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
||||||
BilinearInterpolationData const* xsPtr = xs.data();
|
BilinearInterpolationData const* xsPtr = xs.data();
|
||||||
|
|
||||||
T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
|
// T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
|
||||||
auto computeBilinear = [](double topLeft, double topRight,
|
auto computeBilinear = [](double topLeft, double topRight,
|
||||||
double bottomLeft, double bottomRight,
|
double bottomLeft, double bottomRight,
|
||||||
double xVal, double yVal) {
|
double xVal, double yVal) {
|
||||||
|
@ -214,8 +238,12 @@ namespace helpers {
|
||||||
samediff::Threads::parallel_tad(func, 0, batchSize);
|
samediff::Threads::parallel_tad(func, 0, batchSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename X, typename Z>
|
||||||
static int resizeBilinearFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
|
static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners,
|
||||||
|
bool const halfPixelCenter, NDArray *output) {
|
||||||
|
ImageResizerState st(alignCorners, halfPixelCenter);
|
||||||
|
st.validateAndCalculateOutputSize(images, width, height);
|
||||||
|
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -230,28 +258,20 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for TF compatibility
|
|
||||||
if((center && inHeight < 2) || (center && inWidth < 2)){
|
|
||||||
center = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
|
|
||||||
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
|
|
||||||
|
|
||||||
std::vector<BilinearInterpolationData> ys(outHeight + 1);
|
std::vector<BilinearInterpolationData> ys(outHeight + 1);
|
||||||
std::vector<BilinearInterpolationData> xs(outWidth + 1);
|
std::vector<BilinearInterpolationData> xs(outWidth + 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
// Compute the cached interpolation weights on the x and y dimensions.
|
computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale,
|
||||||
computeInterpolationWeights(outHeight, inHeight, heightScale,
|
|
||||||
ys.data());
|
ys.data());
|
||||||
computeInterpolationWeights(outWidth, inWidth, widthScale, xs.data());
|
computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data());
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Compute the cached interpolation weights on the x and y dimensions.
|
||||||
|
computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale,
|
||||||
|
ys.data());
|
||||||
|
computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data());
|
||||||
|
}
|
||||||
int xsSize = xs.size();
|
int xsSize = xs.size();
|
||||||
// Scale x interpolation weights to avoid a multiplication during iteration.
|
// Scale x interpolation weights to avoid a multiplication during iteration.
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -262,71 +282,84 @@ namespace helpers {
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, xsSize);
|
samediff::Threads::parallel_for(func, 0, xsSize);
|
||||||
|
|
||||||
resizeImage(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output);
|
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template <class Scaler, typename T>
|
||||||
int resizeNeighborFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
|
void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = st.batchSize;
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = st.inHeight;
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = st.inWidth;
|
||||||
const Nd4jLong channels = images->sizeAt(3);
|
const Nd4jLong channels = st.channels;
|
||||||
|
|
||||||
const Nd4jLong outHeight = output->sizeAt(1);
|
const Nd4jLong outHeight = st.outHeight;
|
||||||
const Nd4jLong outWidth = output->sizeAt(2);
|
const Nd4jLong outWidth = st.outWidth;
|
||||||
|
Scaler scaler;
|
||||||
// Handle no-op resizes efficiently.
|
|
||||||
if (outHeight == inHeight && outWidth == inWidth) {
|
|
||||||
output->assign(images);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
|
|
||||||
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
for (auto b = start_x; b < stop_x; b += inc_x) {
|
for (auto b = start_x; b < stop_x; b += inc_x) {
|
||||||
for (auto y = start_y; y < stop_y; y += inc_y) {
|
for (auto y = start_y; y < stop_y; y += inc_y) {
|
||||||
Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(y * heightScale)), inHeight - 1);
|
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(y, st.heightScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(y, st.heightScale)));
|
||||||
|
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
|
inY = nd4j::math::nd4j_max(0LL, inY);
|
||||||
|
}
|
||||||
for (auto x = 0; x < outWidth; ++x) {
|
for (auto x = 0; x < outWidth; ++x) {
|
||||||
Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(x * widthScale)),inWidth - 1);
|
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(x, st.widthScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(x, st.widthScale)));
|
||||||
|
Nd4jLong inX = nd4j::math::nd4j_min(posX,inWidth - 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
|
inX = nd4j::math::nd4j_max(0LL, inX);
|
||||||
|
}
|
||||||
|
// copy pixel over all channels
|
||||||
for (auto e = 0; e < channels; e++)
|
for (auto e = 0; e < channels; e++)
|
||||||
output->p(b, y, x, e, images->e<T>(b, inY, inX, e));
|
output->t<T>(b, y, x, e) = images->t<T>(b, inY, inX, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
|
ImageResizerState st(alignCorners, halfPixelCenter);
|
||||||
|
st.validateAndCalculateOutputSize(images, width, height);
|
||||||
|
|
||||||
|
// Handle no-op resizes efficiently.
|
||||||
|
if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) {
|
||||||
|
output->assign(images);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (halfPixelCenter)
|
||||||
|
resizeNeighbor<HalfPixelScalerNN, T>(st, images, alignCorners, true, output);
|
||||||
|
else
|
||||||
|
resizeNeighbor<LegacyScaler, T>(st, images, alignCorners, false, output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
// Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const &xs,
|
// std::vector<BilinearInterpolationData> const &xs,
|
||||||
std::vector<BilinearInterpolationData> const &ys,
|
// std::vector<BilinearInterpolationData> const &ys,
|
||||||
NDArray *output) {
|
// NDArray *output) {
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_,
|
// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_,
|
||||||
(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
|
// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
|
||||||
LIBND4J_TYPES);
|
// NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
|
// }
|
||||||
|
|
||||||
|
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||||
|
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_,
|
||||||
|
(images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
|
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_,
|
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
(images, width, height, center, output), LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
|
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
||||||
(images, width, height, center, output), LIBND4J_TYPES);
|
(images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -586,16 +619,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Older incorrect scaling method that causes all resizes to have a slight
|
|
||||||
// translation leading to inconsistent results. For example, a flip then a
|
|
||||||
// resize gives different results then a resize then a flip.
|
|
||||||
struct LegacyScaler {
|
|
||||||
LegacyScaler(){};
|
|
||||||
inline float operator()(const int x, const float scale) const {
|
|
||||||
return static_cast<float>(x) * scale;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
|
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
|
||||||
const bool half_pixel_centers,
|
const bool half_pixel_centers,
|
||||||
std::vector<WeightsAndIndices>* x_wais) {
|
std::vector<WeightsAndIndices>* x_wais) {
|
||||||
|
@ -847,7 +870,7 @@ namespace helpers {
|
||||||
// simplified bicubic resize without antialiasing
|
// simplified bicubic resize without antialiasing
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
||||||
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
|
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
|
||||||
int res = st.validateAndCreateOutput(image, width, height);
|
int res = st.validateAndCreateOutput(image, width, height);
|
||||||
|
@ -856,17 +879,17 @@ namespace helpers {
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
|
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
|
||||||
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
// ------------------------------------------------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------------------------------------------------ //
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break;
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
case kResizeLanczos5:
|
case kResizeLanczos5:
|
||||||
case kResizeGaussian:
|
case kResizeGaussian:
|
||||||
|
|
|
@ -13,6 +13,20 @@
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author sgazeos@gmail.com
|
// @author sgazeos@gmail.com
|
||||||
|
@ -32,6 +46,38 @@ namespace helpers {
|
||||||
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
||||||
double interpolarValue;
|
double interpolarValue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Older incorrect scaling method that causes all resizes to have a slight
|
||||||
|
// translation leading to inconsistent results. For example, a flip then a
|
||||||
|
// resize gives different results then a resize then a flip.
|
||||||
|
struct LegacyScaler {
|
||||||
|
_CUDA_HD LegacyScaler(){};
|
||||||
|
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
||||||
|
return static_cast<float>(x) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
||||||
|
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
||||||
|
struct HalfPixelScaler {
|
||||||
|
_CUDA_HD HalfPixelScaler(){};
|
||||||
|
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
||||||
|
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
||||||
|
// sampling code etc assumes pixels are in the old coordinate system.
|
||||||
|
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Utility functions
|
||||||
|
// calculateResizeScale determines the float scaling factor.
|
||||||
|
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
||||||
|
bool alignCorners) {
|
||||||
|
return (alignCorners && outSize > 1)
|
||||||
|
? (inSize - 1) / static_cast<float>(outSize - 1)
|
||||||
|
: inSize / static_cast<float>(outSize);
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// computeInterpolationWeights kernel
|
// computeInterpolationWeights kernel
|
||||||
// outSize - output length
|
// outSize - output length
|
||||||
|
@ -39,6 +85,7 @@ namespace helpers {
|
||||||
// scale - input scale
|
// scale - input scale
|
||||||
// interporationData - result
|
// interporationData - result
|
||||||
//
|
//
|
||||||
|
template <class Scaler>
|
||||||
static __global__ void computeInterpolationWeights(Nd4jLong outSize,
|
static __global__ void computeInterpolationWeights(Nd4jLong outSize,
|
||||||
Nd4jLong inSize,
|
Nd4jLong inSize,
|
||||||
double scale,
|
double scale,
|
||||||
|
@ -48,12 +95,18 @@ namespace helpers {
|
||||||
interpolationData[outSize].topIndex = 0;
|
interpolationData[outSize].topIndex = 0;
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
Scaler scaler;
|
||||||
for (Nd4jLong i = outSize - tid; i >= 0; i -= step) {
|
for (Nd4jLong i = outSize - tid; i >= 0; i -= step) {
|
||||||
double in = i * scale;
|
double in = scaler(i, scale);
|
||||||
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
|
// interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
|
||||||
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
|
// interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
|
||||||
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
|
// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
|
||||||
|
double const in_f = nd4j::math::p_floor<double>(in);
|
||||||
|
double const in_c = nd4j::math::p_ceil<double>(in);
|
||||||
|
interpolationData[i].bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
|
||||||
|
interpolationData[i].topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
|
||||||
|
interpolationData[i].interpolarValue = in - in_f;
|
||||||
|
|
||||||
if (channels) {
|
if (channels) {
|
||||||
math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels);
|
math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels);
|
||||||
math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels);
|
math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels);
|
||||||
|
@ -72,31 +125,33 @@ namespace helpers {
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// resize image with bilinear interpolation algorithm kernel
|
// resize image with bilinear interpolation algorithm kernel
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T, typename Z>
|
||||||
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, T* outputYptr, Nd4jLong* outputShape, Nd4jLong batchSize,
|
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr,
|
||||||
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
Nd4jLong* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels,
|
||||||
|
Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
||||||
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
||||||
|
|
||||||
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
|
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
|
||||||
auto pX = input + batch * inBatchNumValues;
|
auto pX = input + batch * inBatchNumValues;
|
||||||
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
||||||
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
||||||
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
||||||
double yVal = ys_[y].interpolarValue;
|
double yVal = ys_[y].interpolarValue;
|
||||||
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
|
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
|
||||||
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
for (Nd4jLong x = 0; x < outWidth; x++) {
|
||||||
auto xsBottom = xs_[x].bottomIndex;
|
auto xsBottom = xs_[x].bottomIndex;
|
||||||
auto xsTop = xs_[x].topIndex;
|
auto xsTop = xs_[x].topIndex;
|
||||||
auto xVal = xs_[x].interpolarValue;
|
auto xVal = xs_[x].interpolarValue;
|
||||||
// process interpolation for all channels
|
// process interpolation for all channels
|
||||||
for (int c = threadIdx.z; c < channels; c += blockDim.z) {
|
for (int c = 0; c < channels; c++) {
|
||||||
double topLeft(ys_input_lower_ptr[xsBottom + c]);
|
Z topLeft(ys_input_lower_ptr[xsBottom + c]);
|
||||||
double topRight(ys_input_lower_ptr[xsTop + c]);
|
Z topRight(ys_input_lower_ptr[xsTop + c]);
|
||||||
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
Z bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
||||||
double bottomRight(ys_input_upper_ptr[xsTop + c]);
|
Z bottomRight(ys_input_upper_ptr[xsTop + c]);
|
||||||
double top = topLeft + (topRight - topLeft) * xVal;
|
Z top = topLeft + (topRight - topLeft) * xVal;
|
||||||
double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
|
Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
|
||||||
pZ[x * channels + c] = T(top + (bottom - top) * yVal);
|
Z resVal = Z(top + (bottom - top) * yVal);
|
||||||
|
pZ[x * channels + c] = resVal;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +160,7 @@ namespace helpers {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// resize image with
|
// resize image with
|
||||||
template <typename T>
|
template <typename T, typename F>
|
||||||
static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
Nd4jLong outWidth, Nd4jLong channels,
|
||||||
BilinearInterpolationData* xs_,
|
BilinearInterpolationData* xs_,
|
||||||
|
@ -115,12 +170,13 @@ namespace helpers {
|
||||||
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
||||||
Nd4jLong outRowSize = outWidth * channels;
|
Nd4jLong outRowSize = outWidth * channels;
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
T const* pInput = images->getDataBuffer()->specialAsT<T>(); //reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
||||||
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
|
F* pOutput = output->dataBuffer()->specialAsT<F>();//reinterpret_cast<F *>(output->specialBuffer());
|
||||||
dim3 batchSizeBlock(batchSize, 1, 1);
|
dim3 batchSizeBlock(batchSize, 1, 1);
|
||||||
dim3 pictureBlock(outHeight, outWidth, channels);
|
dim3 pictureBlock(outHeight, outWidth, channels);
|
||||||
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
|
resizeImageKernel<T,F><<<256, 256, 256, *stream>>>(pInput, images->getSpecialShapeInfo(), pOutput,
|
||||||
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
|
output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize,
|
||||||
|
inBatchNumValues, xs_, ys_);
|
||||||
|
|
||||||
auto err = cudaStreamSynchronize(*stream);
|
auto err = cudaStreamSynchronize(*stream);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
|
@ -129,8 +185,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T, typename F>
|
||||||
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width,
|
||||||
|
int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -145,19 +202,8 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for TF compatibility
|
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
|
||||||
if((center && inHeight < 2) || (center && inWidth < 2)){
|
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
|
||||||
center = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
|
|
||||||
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
|
|
||||||
|
|
||||||
BilinearInterpolationData* xs_;// = xs.data();
|
BilinearInterpolationData* xs_;// = xs.data();
|
||||||
BilinearInterpolationData* ys_;// = xs.data();
|
BilinearInterpolationData* ys_;// = xs.data();
|
||||||
|
@ -173,12 +219,24 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// Compute the cached interpolation weights on the x and y dimensions.
|
// Compute the cached interpolation weights on the x and y dimensions.
|
||||||
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
if (halfPixelCenter) {
|
||||||
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
computeInterpolationWeights <
|
||||||
|
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
||||||
|
computeInterpolationWeights <
|
||||||
|
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
computeInterpolationWeights <
|
||||||
|
LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
||||||
|
computeInterpolationWeights <
|
||||||
|
LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
||||||
|
}
|
||||||
|
printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth);
|
||||||
NDArray::prepareSpecialUse({output}, {images});
|
NDArray::prepareSpecialUse({output}, {images});
|
||||||
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
resizeImage_<T,F>(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
||||||
|
err = cudaStreamSynchronize(*stream);
|
||||||
NDArray::registerSpecialUse({output}, {images});
|
NDArray::registerSpecialUse({output}, {images});
|
||||||
|
|
||||||
err = cudaFree(xs_);
|
err = cudaFree(xs_);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
||||||
|
@ -197,20 +255,28 @@ namespace helpers {
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape,
|
static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape,
|
||||||
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center) {
|
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) {
|
||||||
|
|
||||||
//for (int b = blockIdx.x; b < batchSize; b += gridDim.x)
|
//for (int b = blockIdx.x; b < batchSize; b += gridDim.x)
|
||||||
if (blockIdx.x < batchSize)
|
if (blockIdx.x < batchSize)
|
||||||
{
|
{
|
||||||
auto b = blockIdx.x;
|
auto b = blockIdx.x;
|
||||||
for (int y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
for (int y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
||||||
Nd4jLong inY = nd4j::math::nd4j_min(
|
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
||||||
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale));
|
||||||
y * heightScale)), inHeight - 1);
|
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
|
||||||
|
if (halfPixelCenters) {
|
||||||
|
inY = nd4j::math::nd4j_max(0LL, inY);
|
||||||
|
}
|
||||||
|
|
||||||
for (int x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
for (int x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
||||||
Nd4jLong inX = nd4j::math::nd4j_min(
|
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
||||||
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale));
|
||||||
x * widthScale)), inWidth - 1);
|
Nd4jLong inX = nd4j::math::nd4j_min(posX, inWidth - 1);
|
||||||
|
if (halfPixelCenters) {
|
||||||
|
inX = nd4j::math::nd4j_max(0LL, inX);
|
||||||
|
}
|
||||||
|
|
||||||
auto start = blockIdx.z * blockDim.z + threadIdx.z;
|
auto start = blockIdx.z * blockDim.z + threadIdx.z;
|
||||||
auto step = blockDim.z * gridDim.z;
|
auto step = blockDim.z * gridDim.z;
|
||||||
|
|
||||||
|
@ -231,7 +297,8 @@ namespace helpers {
|
||||||
// resizeNeighborFunctor - main algorithm by nearest neighbor
|
// resizeNeighborFunctor - main algorithm by nearest neighbor
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
|
||||||
|
bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -246,25 +313,24 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) ||
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
||||||
// wrong input data
|
// // wrong input data
|
||||||
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
// return ND4J_STATUS_BAD_ARGUMENTS;
|
||||||
}
|
// }
|
||||||
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
|
// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight));
|
||||||
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
|
// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth));
|
||||||
auto imagesBuffer = reinterpret_cast<T const*>(images->getSpecialBuffer());
|
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
|
||||||
auto outputBuffer = reinterpret_cast<T*>(output->specialBuffer());
|
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
|
||||||
|
|
||||||
|
auto imagesBuffer = images->getDataBuffer()->specialAsT<T>();//reinterpret_cast<T const*>(images->getSpecialBuffer());
|
||||||
|
auto outputBuffer = output->dataBuffer()->specialAsT<T>();//reinterpret_cast<T*>(output->specialBuffer());
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
//T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong* outputShape,
|
|
||||||
// Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center
|
|
||||||
//input, inputShape, output, outputShape,
|
|
||||||
// batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center
|
|
||||||
NDArray::prepareSpecialUse({output}, {images});
|
NDArray::prepareSpecialUse({output}, {images});
|
||||||
resizeNeighborKernel<T><<<batchSize, outHeight * outWidth, 512, *stream>>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(),
|
resizeNeighborKernel<T><<<batchSize, outHeight * outWidth, 512, *stream>>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(),
|
||||||
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center);
|
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters);
|
||||||
NDArray::registerSpecialUse({output}, {images});
|
NDArray::registerSpecialUse({output}, {images});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -275,39 +341,38 @@ namespace helpers {
|
||||||
void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight,
|
void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight,
|
||||||
Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_,
|
Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_,
|
||||||
BilinearInterpolationData* ys_, NDArray* output) {
|
BilinearInterpolationData* ys_, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(),
|
||||||
|
resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels,
|
||||||
|
xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
|
BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
|
||||||
Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth,
|
Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth,
|
||||||
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), LIBND4J_TYPES);
|
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output),
|
||||||
|
NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images,
|
||||||
|
width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context,
|
||||||
|
// NDArray const* images, int const width, int const height, bool const alignCorners,
|
||||||
|
// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
||||||
|
(context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
||||||
int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Bicubic interpolation
|
// Bicubic interpolation
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Utility functions and classes
|
|
||||||
|
|
||||||
// calculateResizeScale determines the float scaling factor.
|
|
||||||
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
|
||||||
bool alignCorners) {
|
|
||||||
return (alignCorners && outSize > 1)
|
|
||||||
? (inSize - 1) / static_cast<float>(outSize - 1)
|
|
||||||
: inSize / static_cast<float>(outSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ImageResizerState {
|
struct ImageResizerState {
|
||||||
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
|
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
|
||||||
: _alignCorners(alignCorners),
|
: _alignCorners(alignCorners),
|
||||||
|
@ -362,17 +427,6 @@ namespace helpers {
|
||||||
bool _halfPixelCenters;
|
bool _halfPixelCenters;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
|
||||||
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
|
||||||
struct HalfPixelScaler {
|
|
||||||
_CUDA_HD HalfPixelScaler(){};
|
|
||||||
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
|
||||||
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
|
||||||
// sampling code etc assumes pixels are in the old coordinate system.
|
|
||||||
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WeightsAndIndices {
|
struct WeightsAndIndices {
|
||||||
float _weight0;
|
float _weight0;
|
||||||
float _weight1;
|
float _weight1;
|
||||||
|
@ -547,16 +601,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Older incorrect scaling method that causes all resizes to have a slight
|
|
||||||
// translation leading to inconsistent results. For example, a flip then a
|
|
||||||
// resize gives different results then a resize then a flip.
|
|
||||||
struct LegacyScaler {
|
|
||||||
_CUDA_HD LegacyScaler(){};
|
|
||||||
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
|
||||||
return static_cast<float>(x) * scale;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
|
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
|
||||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
@ -906,8 +950,8 @@ namespace helpers {
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break;
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
case kResizeLanczos5:
|
case kResizeLanczos5:
|
||||||
case kResizeGaussian:
|
case kResizeGaussian:
|
||||||
|
|
|
@ -37,15 +37,15 @@ namespace helpers {
|
||||||
kResizeArea
|
kResizeArea
|
||||||
};
|
};
|
||||||
|
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
NDArray* output);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
NDArray* output);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
|
||||||
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool preserveAspectRatio, bool antialias, NDArray* output);
|
bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
|
|
||||||
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,
|
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -212,12 +212,12 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
|
||||||
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139};
|
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f};
|
||||||
Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
||||||
NDArray expGWP(_expGradWpB, _expGradWpS);
|
NDArray expGWP(_expGradWpB, _expGradWpS);
|
||||||
expGWP.permutei({2,3,1,0});
|
expGWP.permutei({2,3,1,0});
|
||||||
|
|
||||||
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747};
|
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f};
|
||||||
Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
||||||
NDArray expGWD(_expGradWdB, _expGradWdS);
|
NDArray expGWD(_expGradWdB, _expGradWdS);
|
||||||
expGWD.permutei({2,3,1,0});
|
expGWD.permutei({2,3,1,0});
|
||||||
|
|
|
@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
|
||||||
|
|
||||||
auto z = NDArrayFactory::create_<float>('f', {5, 1});
|
auto z = NDArrayFactory::create_<float>('f', {5, 1});
|
||||||
|
|
||||||
auto expBuffer = new float[5]{28.00,64.00,100.00,136.00,172.00};
|
auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f};
|
||||||
auto exp = new NDArray(expBuffer, z->getShapeInfo());
|
auto exp = new NDArray(expBuffer, z->getShapeInfo());
|
||||||
|
|
||||||
nd4j::blas::GEMV<float, float, float>::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1);
|
nd4j::blas::GEMV<float, float, float>::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1);
|
||||||
|
@ -3606,7 +3606,9 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {
|
||||||
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<float>('c', {2,3,4});
|
auto input = NDArrayFactory::create<float>('c', {2,3,4});
|
||||||
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.});
|
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f,
|
||||||
|
15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
|
||||||
|
6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
|
||||||
|
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
nd4j::ops::reverse op;
|
nd4j::ops::reverse op;
|
||||||
|
|
|
@ -121,10 +121,10 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, Test_Not_1) {
|
TEST_F(DeclarableOpsTests10, Test_Not_1) {
|
||||||
auto x = NDArrayFactory::create<bool>('c', {4}, {1, 1, 0, 1});
|
auto x = NDArrayFactory::create<bool>('c', {4}, {true, true, false, true});
|
||||||
auto y = NDArrayFactory::create<bool>('c', {4}, {0, 0, 0, 1});
|
auto y = NDArrayFactory::create<bool>('c', {4}, {false, false, false, true});
|
||||||
// auto e = NDArrayFactory::create<bool>('c', {4}, {1, 1, 1, 0});
|
// auto e = NDArrayFactory::create<bool>('c', {4}, {1, 1, 1, 0});
|
||||||
auto e = NDArrayFactory::create<bool>('c', {4}, {0, 0, 1, 0});
|
auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false});
|
||||||
|
|
||||||
nd4j::ops::boolean_not op;
|
nd4j::ops::boolean_not op;
|
||||||
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
|
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
|
||||||
|
@ -245,7 +245,8 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
|
TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
|
||||||
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1});
|
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {true, true, false, false, true, true, true,
|
||||||
|
true, true, true, false, true, true, true, true});
|
||||||
// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1});
|
// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1});
|
||||||
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
|
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
|
||||||
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
|
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
|
||||||
|
@ -623,7 +624,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, range_test12) {
|
TEST_F(DeclarableOpsTests10, range_test12) {
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5});
|
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});
|
||||||
|
|
||||||
nd4j::ops::range op;
|
nd4j::ops::range op;
|
||||||
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
|
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
|
||||||
|
@ -1416,7 +1417,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
|
||||||
|
@ -1470,6 +1471,138 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
|
||||||
|
|
||||||
|
input.assign(0.8f); //linspace(1);
|
||||||
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input, &size}, {}, {}, {false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
ASSERT_NE(*result, ex);
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
|
||||||
|
|
||||||
|
input.assign(0.8f); //linspace(1);
|
||||||
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
ASSERT_NE(*result, ex);
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
|
||||||
|
1., 2., 3., 4.,
|
||||||
|
2.6, 3.6, 4.6, 5.6,
|
||||||
|
5., 6., 7., 8.,
|
||||||
|
7.4, 8.4, 9.4, 10.4,
|
||||||
|
9., 10., 11., 12.,
|
||||||
|
|
||||||
|
4., 5., 6., 7.,
|
||||||
|
5.6, 6.6, 7.6, 8.6,
|
||||||
|
8., 9., 10., 11.,
|
||||||
|
10.4, 11.4, 12.4, 13.4,
|
||||||
|
12., 13., 14., 15.,
|
||||||
|
|
||||||
|
10., 11., 12., 13.,
|
||||||
|
11.6, 12.6, 13.6, 14.6,
|
||||||
|
14., 15., 16., 17.,
|
||||||
|
16.4, 17.4, 18.4, 19.4,
|
||||||
|
18., 19., 20., 21.,
|
||||||
|
|
||||||
|
13., 14., 15., 16.,
|
||||||
|
14.6, 15.6, 16.6, 17.6,
|
||||||
|
17., 18., 19., 20.,
|
||||||
|
19.4, 20.4, 21.4, 22.4,
|
||||||
|
21., 22., 23., 24.
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printIndexedBuffer("Resized to 4x5 bilinear with half pixels");
|
||||||
|
//expected.printIndexedBuffer("Expect for 10x10");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
|
||||||
|
1.f, 2.f, 3.f, 4.f,
|
||||||
|
2.6f, 3.6f, 4.6f, 5.6f,
|
||||||
|
5.f, 6.f, 7.f, 8.f,
|
||||||
|
7.4f, 8.4f, 9.4f, 10.4f,
|
||||||
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
|
4.f, 5.f, 6.f, 7.f,
|
||||||
|
5.6f, 6.6f, 7.6f, 8.6f,
|
||||||
|
8.f, 9.f, 10.f, 11.f,
|
||||||
|
10.4f, 11.4f, 12.4f, 13.4f,
|
||||||
|
12.f, 13.f, 14.f, 15.f,
|
||||||
|
|
||||||
|
10.f, 11.f, 12.f, 13.f,
|
||||||
|
11.6f, 12.6f, 13.6f, 14.6f,
|
||||||
|
14.f, 15.f, 16.f, 17.f,
|
||||||
|
16.4f, 17.4f, 18.4f, 19.4f,
|
||||||
|
18.f, 19.f, 20.f, 21.f,
|
||||||
|
|
||||||
|
13.f, 14.f, 15.f, 16.f,
|
||||||
|
14.6f, 15.6f, 16.6f, 17.6f,
|
||||||
|
17.f, 18.f, 19.f, 20.f,
|
||||||
|
19.4f, 20.4f, 21.4f, 22.4f,
|
||||||
|
21.f, 22.f, 23.f, 24.f
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printBuffer("Resized to 4x5");
|
||||||
|
// expected.printBuffer("Expect for 4x5");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
|
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
|
||||||
|
@ -1857,7 +1990,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {10, 10, 1});
|
auto results = op.execute({&input}, {}, {10, 10}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1986,7 +2119,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {1});
|
auto results = op.execute({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2023,7 +2156,56 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
|
||||||
|
1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printIndexedBuffer("Resized to 4x5");
|
||||||
|
// expected.printIndexedBuffer("Expect for 4x5");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, {
|
||||||
|
1, 2, 3, 4,
|
||||||
1, 2, 3, 4,
|
1, 2, 3, 4,
|
||||||
5, 6, 7, 8,
|
5, 6, 7, 8,
|
||||||
5, 6, 7, 8,
|
5, 6, 7, 8,
|
||||||
|
@ -2065,47 +2247,48 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
NDArray input = NDArrayFactory::create<float>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
5.f, 6.f, 7.f, 8.f,
|
||||||
9, 10, 11, 12,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
5.f, 6.f, 7.f, 8.f,
|
||||||
5, 6, 7, 8,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
9, 10, 11, 12,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
17, 18, 19, 20,
|
17.f, 18.f, 19.f, 20.f,
|
||||||
17, 18, 19, 20,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
21, 22, 23, 24,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
|
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
17, 18, 19, 20,
|
17.f, 18.f, 19.f, 20.f,
|
||||||
17, 18, 19, 20,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
21, 22, 23, 24
|
21.f, 22.f, 23.f, 24.f
|
||||||
});
|
});
|
||||||
//input = 1.f;
|
//input = 1.f;
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5});
|
auto results = op.execute({&input}, {}, {4,5}, {false, true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
NDArray* result = results->at(0);
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
// result->printIndexedBuffer("Resized to 4x5");
|
// result->printIndexedBuffer("Resized to 4x5");
|
||||||
// expected.printIndexedBuffer("Expect for 4x5");
|
// expected.printBuffer("Expect for 4x5");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
@ -2533,7 +2716,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
|
||||||
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
|
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
|
||||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
NDArray expected('c', {1,3,3,1}, {1, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::crop_and_resize op;
|
nd4j::ops::crop_and_resize op;
|
||||||
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
|
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
|
||||||
|
@ -2557,7 +2740,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
||||||
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
|
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
|
||||||
|
|
||||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
NDArray expected('c', {1,3,3,1}, {1, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::crop_and_resize op;
|
nd4j::ops::crop_and_resize op;
|
||||||
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
|
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
|
||||||
|
@ -2726,7 +2909,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
|
||||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
|
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
@ -2971,22 +3154,6 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* public void testFakeQuantAgainstTF_1() {
|
|
||||||
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
|
||||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
|
||||||
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
|
|
||||||
INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5);
|
|
||||||
INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5);
|
|
||||||
|
|
||||||
INDArray out = Nd4j.createUninitialized(x.shape());
|
|
||||||
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
|
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
|
|
||||||
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
|
|
||||||
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
|
||||||
|
|
||||||
assertEquals(expected, out);
|
|
||||||
}*/
|
|
||||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
||||||
NDArray x = NDArrayFactory::create<float>('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
NDArray x = NDArrayFactory::create<float>('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
||||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
||||||
|
@ -3094,12 +3261,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
||||||
|
|
||||||
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
@ -3211,19 +3378,19 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
||||||
|
|
||||||
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20.,
|
NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f,
|
||||||
-19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438,
|
-19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f,
|
||||||
-12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32);
|
-12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32);
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3240,14 +3407,14 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
|
||||||
|
|
||||||
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 ,
|
NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f,
|
||||||
20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 ,
|
20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f,
|
||||||
-12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32);
|
-12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32);
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
nd4j::ops::batchnorm op;
|
||||||
|
@ -3270,7 +3437,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {
|
||||||
NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32);
|
NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32);
|
||||||
NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32);
|
NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, nd4j::DataType::BOOL);
|
NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
NDArray result('c', {2,2,2}, nd4j::DataType::BOOL);
|
NDArray result('c', {2,2,2}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
|
|
|
@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) {
|
||||||
auto input = NDArrayFactory::create<double>('c', {4, 5});
|
auto input = NDArrayFactory::create<double>('c', {4, 5});
|
||||||
auto idx = NDArrayFactory::create<Nd4jLong>('c', {4});
|
auto idx = NDArrayFactory::create<Nd4jLong>('c', {4});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>({0, 0, 0, 1});
|
auto exp = NDArrayFactory::create<bool>({false, false, false, true});
|
||||||
|
|
||||||
int exclusive, reverse;
|
int exclusive, reverse;
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) {
|
||||||
TEST_F(DeclarableOpsTests12, inTopK_5) {
|
TEST_F(DeclarableOpsTests12, inTopK_5) {
|
||||||
auto x = NDArrayFactory::create<double>('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} );
|
auto x = NDArrayFactory::create<double>('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} );
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>('f', {6}, {0, 0, 0, 0, 0, 0});
|
auto y = NDArrayFactory::create<Nd4jLong>('f', {6}, {0, 0, 0, 0, 0, 0});
|
||||||
auto expV = NDArrayFactory::create<bool>('f', {6}, {1, 0, 0, 0, 0, 0 });
|
auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false });
|
||||||
|
|
||||||
nd4j::ops::in_top_k op;
|
nd4j::ops::in_top_k op;
|
||||||
auto result = op.execute({&x, &y}, {}, {2});
|
auto result = op.execute({&x, &y}, {}, {2});
|
||||||
|
|
|
@ -1167,12 +1167,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139,
|
NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f,
|
||||||
0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106,
|
0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f,
|
||||||
0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32);
|
0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1230,12 +1230,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
|
||||||
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
|
@ -1245,18 +1245,19 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289,
|
NDArray expH('c', {sL, bS, 2 * nOut}, {
|
||||||
-0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647,
|
0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f,
|
||||||
0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395,
|
-0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f,
|
||||||
-0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862,
|
0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f,
|
||||||
0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308,
|
-0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f,
|
||||||
-0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128,
|
0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f,
|
||||||
0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
|
-0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f,
|
||||||
|
0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
|
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f,
|
||||||
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
|
-0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
|
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f,
|
||||||
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
|
-0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1328,16 +1329,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406,
|
NDArray expH('c', {bS, sL, 2*nOut}, {
|
||||||
0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 ,
|
0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f,
|
||||||
0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739,
|
0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f,
|
||||||
0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681,
|
0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f,
|
||||||
0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
|
0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f,
|
||||||
|
0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492,
|
NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f,
|
||||||
-0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32);
|
-0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609,
|
NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f,
|
||||||
-0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32);
|
-0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1398,12 +1400,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
|
||||||
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
|
@ -1413,14 +1415,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491,
|
0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f,
|
||||||
0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32);
|
0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f,
|
||||||
|
0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
|
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f,
|
||||||
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
|
-0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f},
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
|
nd4j::DataType::FLOAT32);
|
||||||
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f,
|
||||||
|
-0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1568,12 +1573,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
|
0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f,
|
||||||
0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);
|
0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f,
|
||||||
|
0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1650,16 +1656,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843,
|
NDArray expH('c', {sL, bS, 2*nOut}, {
|
||||||
0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139,
|
0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f,
|
||||||
0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 ,
|
0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f,
|
||||||
0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 ,
|
0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f,
|
||||||
0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32);
|
0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f,
|
||||||
|
0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502,
|
NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f,
|
||||||
-0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32);
|
-0.103843f, -0.103843f, -0.103843f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425,
|
NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f,
|
||||||
-0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32);
|
-0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1731,14 +1738,20 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0.,
|
0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f,
|
||||||
0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0.,
|
0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f,
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f,
|
||||||
|
0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1799,25 +1812,26 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
|
||||||
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
|
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx = 0.003;
|
Wx = 0.003f;
|
||||||
Wr = 0.006;
|
Wr = 0.006f;
|
||||||
b = 0.5;
|
b = 0.5f;
|
||||||
hI = 1.;
|
hI = 1.f;
|
||||||
cI = 2.;
|
cI = 2.f;
|
||||||
Wp = -0.05;
|
Wp = -0.05f;
|
||||||
|
|
||||||
std::initializer_list<double> tArgs = {cellClip};
|
std::initializer_list<double> tArgs = {cellClip};
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0.,
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f,
|
||||||
0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627,
|
0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
|
0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f,
|
||||||
0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32);
|
0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
|
||||||
|
0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1878,18 +1892,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
|
||||||
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
|
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
cI({1,2, 0,0, 0,0}) = -2;
|
cI({1,2, 0,0, 0,0}) = -2;
|
||||||
Wp({0,1, 0,0}) = -0.05;
|
Wp({0,1, 0,0}) = -0.05f;
|
||||||
Wp({1,2, 0,0}) = 0.05;
|
Wp({1,2, 0,0}) = 0.05f;
|
||||||
|
|
||||||
std::initializer_list<double> tArgs = {cellClip};
|
std::initializer_list<double> tArgs = {cellClip};
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
|
@ -1905,10 +1919,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315,
|
NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f,
|
||||||
0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702,
|
NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f,
|
||||||
0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
|
|
@ -148,8 +148,8 @@ TEST_F(DeclarableOpsTests15, Test_standarize_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5}, {1., 1., 1., 1., 1.});
|
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
auto eps = NDArrayFactory::create<float>('c', {5}, {0., 0., 0., 0., 0.});
|
auto eps = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::standardize_bp op;
|
nd4j::ops::standardize_bp op;
|
||||||
auto result = op.execute({&x, &eps}, {}, {0}, {});
|
auto result = op.execute({&x, &eps}, {}, {0}, {});
|
||||||
|
|
|
@ -1591,7 +1591,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -71.);
|
ASSERT_TRUE(result->e<float>(0) == -71.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -69.);
|
ASSERT_TRUE(result->e<float>(0) == -69.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1630,8 +1630,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -24.);
|
ASSERT_TRUE(result->e<float>(0) == -24.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1655,8 +1655,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {1,1});
|
auto weights = NDArrayFactory::create<float>('c', {1,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
@ -1680,10 +1680,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
weights.p(0, 0.);
|
weights.p(0, 0.f);
|
||||||
weights.p(1, 0.);
|
weights.p(1, 0.f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
|
|
@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
|
||||||
b.linspace(10.);
|
b.linspace(10.);
|
||||||
x.assign(1.);
|
x.assign(1.);
|
||||||
|
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.,1.,1.,1.,1.,1.,1.});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.execute({&a, &b, &x}, {}, {});
|
||||||
|
@ -2292,9 +2292,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2329,9 +2329,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2366,9 +2366,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2421,9 +2421,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
|
@ -4084,7 +4084,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) {
|
||||||
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
|
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
|
||||||
|
|
||||||
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
|
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
|
||||||
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
|
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f});
|
||||||
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
|
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
|
||||||
nd4j::ops::softsign ffOP;
|
nd4j::ops::softsign ffOP;
|
||||||
nd4j::ops::softsign_bp bpOp;
|
nd4j::ops::softsign_bp bpOp;
|
||||||
|
|
|
@ -661,9 +661,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
|
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
|
||||||
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
|
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
|
||||||
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
|
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
|
||||||
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
|
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
||||||
|
|
||||||
|
@ -685,9 +685,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
||||||
TEST_F(JavaInteropTests, Test_Greater_2) {
|
TEST_F(JavaInteropTests, Test_Greater_2) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
|
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
|
||||||
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
|
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
|
||||||
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
|
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
|
||||||
|
|
||||||
nd4j::ops::greater op;
|
nd4j::ops::greater op;
|
||||||
|
|
||||||
|
|
|
@ -1163,10 +1163,10 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) {
|
||||||
NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
||||||
NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
NDArray exp1('c', {3}, {4., 20., 36.}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {2,3}, {-10., -2., 6.,14., 22., 30.}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {4}, {38., 41., 44., 47.}, nd4j::DataType::FLOAT32);
|
NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp4('c', {4}, {114., 117., 120., 123.}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
|
||||||
NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2});
|
NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2});
|
||||||
|
@ -1271,8 +1271,10 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) {
|
||||||
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray exp1('c', {3,2}, {-88., -124., 6., -2., 22., 14.}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {6,4}, {-36., -44., -52., -60.,-42., -52., -62., -72.,2., 0., -2., -4.,6., 4., 2., 0.,10., 8., 6., 4.,14., 12., 10., 8.}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f,
|
||||||
|
-4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
|
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
|
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
@ -1400,10 +1402,10 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::FLOAT32);
|
NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1503,7 +1505,7 @@ TEST_F(NDArrayCudaBasicsTests, EqualityTest1) {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
|
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
@ -1511,11 +1513,11 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
||||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32);
|
NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp5('c', {2}, {21.5,5}, nd4j::DataType::FLOAT32);
|
NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1575,17 +1577,17 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::BOOL);
|
NDArray z1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL);
|
NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
|
NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL);
|
NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL);
|
NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
|
NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
|
NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
|
NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL);
|
NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL);
|
NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1643,7 +1645,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
|
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
|
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
|
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
|
||||||
|
@ -1912,7 +1914,7 @@ TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3)
|
||||||
TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
||||||
{
|
{
|
||||||
double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.};
|
double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.};
|
||||||
NDArray a('c', {4,4}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, nd4j::DataType::FLOAT32);
|
NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, nd4j::DataType::FLOAT32);
|
||||||
auto x = NDArrayFactory::create<double>('c', {3, 2, 1});
|
auto x = NDArrayFactory::create<double>('c', {3, 2, 1});
|
||||||
auto y = NDArrayFactory::create<double>('c', {1, 2});
|
auto y = NDArrayFactory::create<double>('c', {1, 2});
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {3, 2, 2});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {3, 2, 2});
|
||||||
|
@ -1928,7 +1930,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, assign_2)
|
TEST_F(NDArrayCudaBasicsTests, assign_2)
|
||||||
{
|
{
|
||||||
NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray y('c', {4}, nd4j::DataType::INT32);
|
NDArray y('c', {4}, nd4j::DataType::INT32);
|
||||||
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
@ -1945,30 +1947,30 @@ TEST_F(NDArrayCudaBasicsTests, subarray_1)
|
||||||
NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99};
|
Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99};
|
||||||
float buffExpX0[] = {1.000000, 13.000000};
|
float buffExpX0[] = {1.f, 13.f};
|
||||||
Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99};
|
Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99};
|
||||||
float buffExpX1[] = {2.000000, 14.000000};
|
float buffExpX1[] = {2.f, 14.f};
|
||||||
Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99};
|
||||||
float buffExpX2[] = {1.000000, 13.000000};
|
float buffExpX2[] = {1.f, 13.f};
|
||||||
Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99};
|
||||||
float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
|
float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
|
||||||
Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99};
|
||||||
float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
|
float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
|
||||||
Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99};
|
Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99};
|
||||||
float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000};
|
float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f};
|
||||||
|
|
||||||
Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99};
|
Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99};
|
||||||
float buffExpY0[] = {1.000000, 2.000000};
|
float buffExpY0[] = {1.f, 2.f};
|
||||||
Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99};
|
Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99};
|
||||||
float buffExpY1[] = {7.000000, 8.000000};
|
float buffExpY1[] = {7.f, 8.f};
|
||||||
Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102};
|
Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102};
|
||||||
float buffExpY2[] = {1.000000, 2.000000};
|
float buffExpY2[] = {1.f, 2.f};
|
||||||
Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99};
|
Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99};
|
||||||
float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
|
float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
|
||||||
Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102};
|
Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102};
|
||||||
float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
|
float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
|
||||||
Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99};
|
Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99};
|
||||||
float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000};
|
float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f};
|
||||||
|
|
||||||
|
|
||||||
NDArray x0 = x(0, {1,2});
|
NDArray x0 = x(0, {1,2});
|
||||||
|
@ -2121,7 +2123,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
||||||
auto x = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
auto x = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x->reshapei('c', {3, 4, 5});
|
x->reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x->permutei({0, 1, 2});
|
x->permutei({0, 1, 2});
|
||||||
|
@ -2138,7 +2140,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
@ -2153,7 +2155,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
@ -2170,7 +2172,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) {
|
||||||
auto xx = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
auto xx = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
||||||
// auto x = *xx;
|
// auto x = *xx;
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
// x.reshapei('c', {3, 4, 5});
|
// x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
// x.permutei({0, 1, 2});
|
// x.permutei({0, 1, 2});
|
||||||
|
@ -2188,7 +2190,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) {
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
for (int l = 0; l < x.lengthOf(); l++)
|
for (int l = 0; l < x.lengthOf(); l++)
|
||||||
x.p(l, float(l + 1.f));
|
x.p(l, float(l + 1.f));
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.image;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -43,20 +44,25 @@ import java.util.Map;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class ResizeBilinear extends DynamicCustomOp {
|
public class ResizeBilinear extends DynamicCustomOp {
|
||||||
protected boolean alignCorners = false;
|
protected boolean alignCorners = false;
|
||||||
|
protected boolean halfPixelCenters = false;
|
||||||
protected Integer height = null;
|
protected Integer height = null;
|
||||||
protected Integer width = null;
|
protected Integer width = null;
|
||||||
|
|
||||||
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
|
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width,
|
||||||
|
boolean alignCorners, boolean halfPixelCenters){
|
||||||
super(sd, input);
|
super(sd, input);
|
||||||
this.alignCorners = alignCorners;
|
this.alignCorners = alignCorners;
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
|
this.halfPixelCenters = halfPixelCenters;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
|
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width,
|
||||||
|
boolean alignCorners, boolean halfPixelCenters) {
|
||||||
super(new INDArray[]{x}, new INDArray[]{z});
|
super(new INDArray[]{x}, new INDArray[]{z});
|
||||||
this.alignCorners = alignCorners;
|
this.alignCorners = alignCorners;
|
||||||
|
this.halfPixelCenters = halfPixelCenters;
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
@ -76,7 +82,12 @@ public class ResizeBilinear extends DynamicCustomOp {
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
||||||
|
|
||||||
this.alignCorners = attributesForNode.get("align_corners").getB();
|
val attrC = attributesForNode.get("align_corners");
|
||||||
|
val attrH = attributesForNode.get("half_pixel_centers");
|
||||||
|
|
||||||
|
this.alignCorners = attrC != null ? attrC.getB() : false;
|
||||||
|
this.halfPixelCenters = attrH != null ? attrH.getB() : false;
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,8 +98,7 @@ public class ResizeBilinear extends DynamicCustomOp {
|
||||||
iArguments.add(Long.valueOf(height));
|
iArguments.add(Long.valueOf(height));
|
||||||
iArguments.add(Long.valueOf(width));
|
iArguments.add(Long.valueOf(width));
|
||||||
}
|
}
|
||||||
iArguments.add(alignCorners ? 1L : 0L);
|
addBArgument(alignCorners, halfPixelCenters);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -4584,6 +4584,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* returns reference on array element with given index
|
* returns reference on array element with given index
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
* i - element index in array
|
* i - element index in array
|
||||||
|
@ -5171,6 +5172,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
@ -5179,6 +5182,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// #ifndef __JAVACPP_HACK__
|
// #ifndef __JAVACPP_HACK__
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
|
|
@ -4587,6 +4587,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* returns reference on array element with given index
|
* returns reference on array element with given index
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
* i - element index in array
|
* i - element index in array
|
||||||
|
@ -5174,6 +5175,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
@ -5182,6 +5185,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// #ifndef __JAVACPP_HACK__
|
// #ifndef __JAVACPP_HACK__
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
@ -18280,7 +18285,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
/**
|
/**
|
||||||
* This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
|
* This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
|
||||||
* terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
|
* terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
|
||||||
* Currently the case n = 0 is not supported.
|
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
|
* 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
|
||||||
|
@ -18309,6 +18313,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This op calculates digamma function psi(x) = derivative of log(Gamma(x))
|
||||||
|
*
|
||||||
|
* Input arrays:
|
||||||
|
* 0: x - abscissa points where to evaluate the digamma function, type float
|
||||||
|
*
|
||||||
|
* Output array:
|
||||||
|
* 0: values of digamma function at corresponding x, type float
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_digamma)
|
||||||
|
@Namespace("nd4j::ops") public static class digamma extends DeclarableOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public digamma(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public digamma(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public digamma position(long position) {
|
||||||
|
return (digamma)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public digamma() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
|
* This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
|
@ -18398,9 +18430,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image hue by delta
|
* This operation adjusts image hue by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing delta
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - delta value
|
* 0 - optional argument, delta value
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
|
@ -18427,9 +18460,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image saturation by delta
|
* This operation adjusts image saturation by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing saturation factor
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - saturation factor
|
* 0 - optional argument, saturation factor
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
|
@ -18456,9 +18490,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
|
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing saturation contrast factor
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - contrast factor
|
* 0 - optional argument, contrast factor
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_adjust_contrast)
|
// #if NOT_EXCLUDED(OP_adjust_contrast)
|
||||||
|
|
|
@ -117,9 +117,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
||||||
"fake_quant/min_max_args_per_channel.*",
|
"fake_quant/min_max_args_per_channel.*",
|
||||||
|
|
||||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
|
|
||||||
"resize_bilinear/int32.*",
|
|
||||||
|
|
||||||
// Suggesting TF 1.15 bug
|
// Suggesting TF 1.15 bug
|
||||||
"non_max_suppression_v2/float16.*",
|
"non_max_suppression_v2/float16.*",
|
||||||
|
|
||||||
|
|
|
@ -972,7 +972,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
INDArray x = Nd4j.rand(1, 2,3,4);
|
INDArray x = Nd4j.rand(1, 2,3,4);
|
||||||
INDArray z = Nd4j.createUninitialized(x.shape());
|
INDArray z = Nd4j.createUninitialized(x.shape());
|
||||||
boolean align = false;
|
boolean align = false;
|
||||||
val op = new ResizeBilinear(x, z, 10, 10, align);
|
val op = new ResizeBilinear(x, z, 10, 10, align, false);
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1174,6 +1174,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, x);
|
assertEquals(expected, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("AS failed 2019/12/04")
|
||||||
@Test
|
@Test
|
||||||
public void testPolygamma() {
|
public void testPolygamma() {
|
||||||
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
|
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
|
||||||
|
|
Loading…
Reference in New Issue