correct output empty shapes deducing in split op (#311)
* - correct output empty shapes deducing in split op Signed-off-by: Yurii <iuriish@yahoo.com> * java test fixed Signed-off-by: raver119 <raver119@gmail.com> * - split broadcast::exec function on individual functions corresponding to switch arg Signed-off-by: Yurii <iuriish@yahoo.com> * - split broadcast::exec _int and _bool function on individual functions corresponding to switch arg Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
41bde8f885
commit
e42b4e96c3
|
@ -572,6 +572,272 @@ template <typename X, typename Y, typename Z>
|
||||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
|
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], *y);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(*x, y[i0]);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], y[i0]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto y0 = y + i0 * yStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], *y0);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(*x0, y0[i1]);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], y0[i1]);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
|
||||||
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
|
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], *y1);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(*x1, y1[i2]);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], y1[i2]);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
|
||||||
|
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
||||||
|
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
||||||
|
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
||||||
|
|
||||||
|
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], *y2);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(*x2, y2[i3]);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], y2[i3]);
|
||||||
|
else
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
|
||||||
|
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
||||||
|
|
||||||
|
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
||||||
|
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
||||||
|
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
||||||
|
|
||||||
|
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], *y3);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(*x3, y3[i4]);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], y3[i4]);
|
||||||
|
else
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Y, typename Z, typename OpType>
|
||||||
|
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
||||||
|
|
||||||
|
for (uint j = 0; j < rank; ++j) {
|
||||||
|
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
|
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
||||||
|
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
|
@ -582,220 +848,26 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
||||||
Z* z = reinterpret_cast<Z*>(vz);
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
const char zOrder = shape::order(zShapeInfo);
|
|
||||||
|
|
||||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
switch (rank) {
|
switch (rank) {
|
||||||
|
|
||||||
case 1: {
|
case 1:
|
||||||
|
execRank1<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
break;
|
||||||
|
case 2:
|
||||||
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
execRank2<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
break;
|
||||||
z[i0] = OpType::op(x[i0], *y);
|
case 3:
|
||||||
}
|
execRank3<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
break;
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
case 4:
|
||||||
z[i0] = OpType::op(*x, y[i0]);
|
execRank4<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
}
|
break;
|
||||||
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
case 5:
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
execRank5<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
z[i0] = OpType::op(x[i0], y[i0]);
|
break;
|
||||||
}
|
default:
|
||||||
else {
|
execDefault<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
|
||||||
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 2: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
for (auto i0 = start; i0 < stop; ++i0) {
|
|
||||||
|
|
||||||
auto x0 = x + i0 * xStrd0;
|
|
||||||
auto y0 = y + i0 * yStrd0;
|
|
||||||
auto z0 = z + i0 * zStrd0;
|
|
||||||
|
|
||||||
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], *y0);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(*x0, y0[i1]);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], y0[i1]);
|
|
||||||
else
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 3: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
|
|
||||||
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
|
||||||
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
|
||||||
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
|
||||||
|
|
||||||
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], *y1);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(*x1, y1[i2]);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], y1[i2]);
|
|
||||||
else
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 4: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
|
|
||||||
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
|
||||||
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
|
||||||
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
|
||||||
|
|
||||||
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], *y2);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(*x2, y2[i3]);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], y2[i3]);
|
|
||||||
else
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 5: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
|
||||||
|
|
||||||
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
|
||||||
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
|
||||||
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
|
||||||
|
|
||||||
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], *y3);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(*x3, y3[i4]);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], y3[i4]);
|
|
||||||
else
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
default: {
|
|
||||||
|
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -453,6 +453,271 @@ namespace broadcast {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], *y, extraParams);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(*x, y[i0], extraParams);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], y[i0], extraParams);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto y0 = y + i0 * yStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
|
||||||
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
|
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
|
||||||
|
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
||||||
|
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
||||||
|
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
||||||
|
|
||||||
|
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
|
||||||
|
else
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
|
||||||
|
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
||||||
|
|
||||||
|
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
||||||
|
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
||||||
|
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
||||||
|
|
||||||
|
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
|
||||||
|
else
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z, typename OpType>
|
||||||
|
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||||
|
|
||||||
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
||||||
|
|
||||||
|
for (uint j = 0; j < rank; ++j) {
|
||||||
|
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
|
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
||||||
|
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||||
|
}
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
|
@ -468,220 +733,26 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
X* extraParams = reinterpret_cast<X*>(vextraParams);
|
X* extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
const char zOrder = shape::order(zShapeInfo);
|
|
||||||
|
|
||||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
switch (rank) {
|
switch (rank) {
|
||||||
|
|
||||||
case 1: {
|
case 1:
|
||||||
|
execRank1<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
break;
|
||||||
|
case 2:
|
||||||
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
execRank2<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
break;
|
||||||
z[i0] = OpType::op(x[i0], *y, extraParams);
|
case 3:
|
||||||
}
|
execRank3<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
break;
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
case 4:
|
||||||
z[i0] = OpType::op(*x, y[i0], extraParams);
|
execRank4<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
}
|
break;
|
||||||
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
case 5:
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
execRank5<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
z[i0] = OpType::op(x[i0], y[i0], extraParams);
|
break;
|
||||||
}
|
default:
|
||||||
else {
|
execDefault<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
|
||||||
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 2: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
for (auto i0 = start; i0 < stop; ++i0) {
|
|
||||||
|
|
||||||
auto x0 = x + i0 * xStrd0;
|
|
||||||
auto y0 = y + i0 * yStrd0;
|
|
||||||
auto z0 = z + i0 * zStrd0;
|
|
||||||
|
|
||||||
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
|
|
||||||
else
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 3: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
|
|
||||||
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
|
||||||
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
|
||||||
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
|
||||||
|
|
||||||
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
|
|
||||||
else
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 4: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
|
|
||||||
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
|
||||||
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
|
||||||
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
|
||||||
|
|
||||||
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
|
|
||||||
else
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 5: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
|
||||||
|
|
||||||
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
|
||||||
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
|
||||||
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
|
||||||
|
|
||||||
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
|
|
||||||
else
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
default: {
|
|
||||||
|
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -439,6 +439,271 @@ namespace functions {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], *y);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(*x, y[i0]);
|
||||||
|
}
|
||||||
|
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0] = OpType::op(x[i0], y[i0]);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0)
|
||||||
|
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
for (auto i0 = start; i0 < stop; ++i0) {
|
||||||
|
|
||||||
|
auto x0 = x + i0 * xStrd0;
|
||||||
|
auto y0 = y + i0 * yStrd0;
|
||||||
|
auto z0 = z + i0 * zStrd0;
|
||||||
|
|
||||||
|
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], *y0);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(*x0, y0[i1]);
|
||||||
|
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1] = OpType::op(x0[i1], y0[i1]);
|
||||||
|
else
|
||||||
|
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
||||||
|
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
|
||||||
|
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
||||||
|
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
||||||
|
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
||||||
|
|
||||||
|
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], *y1);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(*x1, y1[i2]);
|
||||||
|
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2] = OpType::op(x1[i2], y1[i2]);
|
||||||
|
else
|
||||||
|
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
||||||
|
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
|
||||||
|
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
||||||
|
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
||||||
|
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
||||||
|
|
||||||
|
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], *y2);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(*x2, y2[i3]);
|
||||||
|
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3] = OpType::op(x2[i3], y2[i3]);
|
||||||
|
else
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
||||||
|
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||||
|
|
||||||
|
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||||
|
|
||||||
|
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||||
|
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||||
|
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||||
|
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||||
|
|
||||||
|
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||||
|
|
||||||
|
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_3D {
|
||||||
|
|
||||||
|
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
||||||
|
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
||||||
|
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
||||||
|
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
||||||
|
|
||||||
|
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
||||||
|
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
||||||
|
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
||||||
|
|
||||||
|
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], *y3);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(*x3, y3[i4]);
|
||||||
|
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4] = OpType::op(x3[i4], y3[i4]);
|
||||||
|
else
|
||||||
|
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
||||||
|
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename OpType>
|
||||||
|
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||||
|
|
||||||
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
|
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
||||||
|
|
||||||
|
for (uint j = 0; j < rank; ++j) {
|
||||||
|
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
|
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
||||||
|
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
||||||
|
|
||||||
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
@ -452,220 +717,26 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
X* z = reinterpret_cast<X*>(vz);
|
X* z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||||
const char zOrder = shape::order(zShapeInfo);
|
|
||||||
|
|
||||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
|
||||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
|
||||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
|
||||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
|
||||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
|
||||||
|
|
||||||
switch (rank) {
|
switch (rank) {
|
||||||
|
|
||||||
case 1: {
|
case 1:
|
||||||
|
execRank1<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
break;
|
||||||
|
case 2:
|
||||||
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
|
execRank2<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
break;
|
||||||
z[i0] = OpType::op(x[i0], *y);
|
case 3:
|
||||||
}
|
execRank3<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
|
break;
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
case 4:
|
||||||
z[i0] = OpType::op(*x, y[i0]);
|
execRank4<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
}
|
break;
|
||||||
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
|
case 5:
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
execRank5<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
z[i0] = OpType::op(x[i0], y[i0]);
|
break;
|
||||||
}
|
default:
|
||||||
else {
|
execDefault<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||||
for (auto i0 = start; i0 < stop; ++i0)
|
|
||||||
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 2: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
for (auto i0 = start; i0 < stop; ++i0) {
|
|
||||||
|
|
||||||
auto x0 = x + i0 * xStrd0;
|
|
||||||
auto y0 = y + i0 * yStrd0;
|
|
||||||
auto z0 = z + i0 * zStrd0;
|
|
||||||
|
|
||||||
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], *y0);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(*x0, y0[i1]);
|
|
||||||
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1] = OpType::op(x0[i1], y0[i1]);
|
|
||||||
else
|
|
||||||
for (uint i1 = 0; i1 < zAxis1; ++i1)
|
|
||||||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 3: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
|
|
||||||
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
|
|
||||||
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
|
|
||||||
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
|
|
||||||
|
|
||||||
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], *y1);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(*x1, y1[i2]);
|
|
||||||
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2] = OpType::op(x1[i2], y1[i2]);
|
|
||||||
else
|
|
||||||
for (uint i2 = 0; i2 < zAxis2; ++i2)
|
|
||||||
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 4: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
|
|
||||||
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
|
|
||||||
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
|
|
||||||
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
|
|
||||||
|
|
||||||
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], *y2);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(*x2, y2[i3]);
|
|
||||||
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3] = OpType::op(x2[i3], y2[i3]);
|
|
||||||
else
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3)
|
|
||||||
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case 5: {
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_3D {
|
|
||||||
|
|
||||||
for (auto i0 = start_x; i0 < stop_x; ++i0) {
|
|
||||||
for (auto i1 = start_y; i1 < stop_y; ++i1) {
|
|
||||||
for (auto i2 = start_z; i2 < stop_z; ++i2) {
|
|
||||||
for (uint i3 = 0; i3 < zAxis3; ++i3) {
|
|
||||||
|
|
||||||
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
|
|
||||||
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
|
|
||||||
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
|
|
||||||
|
|
||||||
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], *y3);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(*x3, y3[i4]);
|
|
||||||
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4] = OpType::op(x3[i4], y3[i4]);
|
|
||||||
else
|
|
||||||
for (uint i4 = 0; i4 < zAxis4; ++i4)
|
|
||||||
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
default: {
|
|
||||||
|
|
||||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
|
||||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
|
||||||
|
|
||||||
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
|
|
||||||
|
|
||||||
for (uint j = 0; j < rank; ++j) {
|
|
||||||
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
|
||||||
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
|
|
||||||
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ namespace ops {
|
||||||
input = a;
|
input = a;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
|
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
|
||||||
if(input->isEmpty()){
|
if(input->isEmpty()){
|
||||||
for( int i=0; i< num_splits; i++ ){
|
for( int i=0; i< num_splits; i++ ){
|
||||||
|
@ -112,17 +112,17 @@ namespace ops {
|
||||||
inputVar = 0;
|
inputVar = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto shapes = SHAPELIST();
|
auto shapes = SHAPELIST();
|
||||||
|
|
||||||
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
|
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
|
||||||
if(INPUT_VARIABLE(inputVar)->isEmpty()){
|
// if(INPUT_VARIABLE(inputVar)->isEmpty()){
|
||||||
for (int e = 0; e < num_splits; e++) {
|
// for (int e = 0; e < num_splits; e++) {
|
||||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
|
// auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
|
||||||
shapes->push_back(empty);
|
// shapes->push_back(empty);
|
||||||
}
|
// }
|
||||||
return shapes;
|
// return shapes;
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (block.numI() == 2)
|
if (block.numI() == 2)
|
||||||
axis = INT_ARG(1);
|
axis = INT_ARG(1);
|
||||||
|
@ -135,9 +135,9 @@ namespace ops {
|
||||||
for (int e = 0; e < shape::rank(input); e++)
|
for (int e = 0; e < shape::rank(input); e++)
|
||||||
if (e == axis)
|
if (e == axis)
|
||||||
shape[e] = shape::sizeAt(input, e) / num_splits;
|
shape[e] = shape::sizeAt(input, e) / num_splits;
|
||||||
else
|
else
|
||||||
shape[e] = shape::sizeAt(input, e);
|
shape[e] = shape::sizeAt(input, e);
|
||||||
|
|
||||||
for (int e = 0; e < num_splits; e++) {
|
for (int e = 0; e < num_splits; e++) {
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape);
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape);
|
||||||
shapes->push_back(newShape);
|
shapes->push_back(newShape);
|
||||||
|
|
|
@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
|
|
||||||
ASSERT_EQ(e, *result.at(0));
|
ASSERT_EQ(e, *result.at(0));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -170,7 +170,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(res1));
|
ASSERT_TRUE(exp.equalsTo(res1));
|
||||||
ASSERT_TRUE(expIdx.equalsTo(res2));
|
ASSERT_TRUE(expIdx.equalsTo(res2));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -187,7 +187,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(resA));
|
ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) {
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
ASSERT_TRUE(exp.isSameShape(resA));
|
ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
|
||||||
ASSERT_TRUE(exp2.equalsTo(res2));
|
ASSERT_TRUE(exp2.equalsTo(res2));
|
||||||
ASSERT_TRUE(exp3.equalsTo(res3));
|
ASSERT_TRUE(exp3.equalsTo(res3));
|
||||||
//ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
//ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -245,7 +245,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
|
||||||
ASSERT_TRUE(exp1.equalsTo(res.at(0)));
|
ASSERT_TRUE(exp1.equalsTo(res.at(0)));
|
||||||
ASSERT_TRUE(exp2.equalsTo(res.at(1)));
|
ASSERT_TRUE(exp2.equalsTo(res.at(1)));
|
||||||
//ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
//ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -263,7 +263,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) {
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
ASSERT_TRUE(exp.isSameShape(resA));
|
ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -281,7 +281,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) {
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
ASSERT_TRUE(exp.isSameShape(resA));
|
ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) {
|
||||||
//ASSERT_TRUE(exp.equalsTo(resA));
|
//ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
//ASSERT_TRUE(exp.isSameShape(resA));
|
//ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) {
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
ASSERT_TRUE(exp.isSameShape(resA));
|
ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -337,7 +337,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) {
|
||||||
//ASSERT_TRUE(exp.equalsTo(resA));
|
//ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
//ASSERT_TRUE(exp.isSameShape(resA));
|
//ASSERT_TRUE(exp.isSameShape(resA));
|
||||||
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
// ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -354,7 +354,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) {
|
||||||
auto resA = res.at(0);
|
auto resA = res.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -371,7 +371,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
|
||||||
auto resA = res.at(0);
|
auto resA = res.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(resA));
|
ASSERT_TRUE(exp.equalsTo(resA));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
@ -567,7 +567,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -764,50 +764,6 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests10, split_test4) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
|
|
||||||
auto axis = NDArrayFactory::create<double>(-1);
|
|
||||||
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
|
|
||||||
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
|
|
||||||
|
|
||||||
sd::ops::split op;
|
|
||||||
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
|
||||||
|
|
||||||
auto out1 = results.at(0);
|
|
||||||
auto out2 = results.at(1);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
|
||||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests10, split_test5) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
|
|
||||||
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
|
|
||||||
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
|
|
||||||
|
|
||||||
sd::ops::split op;
|
|
||||||
auto results = op.evaluate({&input}, {}, {2,-1},{});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
|
||||||
|
|
||||||
auto out1 = results.at(0);
|
|
||||||
auto out2 = results.at(1);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
|
||||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
|
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
|
||||||
|
|
||||||
|
@ -1464,7 +1420,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
|
||||||
//expected.printIndexedBuffer("Expect for 10x10");
|
//expected.printIndexedBuffer("Expect for 10x10");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
||||||
|
@ -1511,7 +1467,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
||||||
// expected.printBuffer("Expect for 4x5");
|
// expected.printBuffer("Expect for 4x5");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
|
@ -1569,7 +1525,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
// expected.printShapeInfo("Expect shape");
|
// expected.printShapeInfo("Expect shape");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
|
@ -1724,7 +1680,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
|
||||||
// expected.printShapeInfo("Expect shape");
|
// expected.printShapeInfo("Expect shape");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2053,7 +2009,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
|
||||||
auto res = result.at(0);
|
auto res = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(expect.equalsTo(res));
|
ASSERT_TRUE(expect.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
|
@ -2276,7 +2232,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
|
TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
|
||||||
|
@ -2315,7 +2271,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
ASSERT_TRUE(expected.isSameShapeStrict(*result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -73,7 +73,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -94,7 +94,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -115,7 +115,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -135,7 +135,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -156,7 +156,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -177,7 +177,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -198,7 +198,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -218,7 +218,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -242,7 +242,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -273,7 +273,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -339,7 +339,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_11) {
|
||||||
|
|
||||||
ASSERT_EQ(m, *z);
|
ASSERT_EQ(m, *z);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -373,7 +373,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests4, biasadd_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, biasadd_2) {
|
TEST_F(DeclarableOpsTests4, biasadd_2) {
|
||||||
|
@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests4, biasadd_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, biasadd_3) {
|
TEST_F(DeclarableOpsTests4, biasadd_3) {
|
||||||
|
@ -581,7 +581,7 @@ TEST_F(DeclarableOpsTests4, biasadd_3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -609,7 +609,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_1) {
|
||||||
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
||||||
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_2) {
|
||||||
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
ASSERT_TRUE(gradB->isSameShape(expGradB));
|
||||||
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
ASSERT_TRUE(gradB->equalsTo(expGradB));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, biasadd_4) {
|
TEST_F(DeclarableOpsTests4, biasadd_4) {
|
||||||
|
@ -672,7 +672,7 @@ TEST_F(DeclarableOpsTests4, Test_Fill_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
|
TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
|
||||||
|
@ -694,7 +694,7 @@ TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
|
||||||
// ASSERT_TRUE(exp.isSameShape(z));
|
// ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
|
TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
|
||||||
|
@ -713,7 +713,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
|
||||||
// z->printShapeInfo("Flatten1 shape");
|
// z->printShapeInfo("Flatten1 shape");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
|
TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
|
||||||
|
@ -734,7 +734,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
|
TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
|
||||||
|
@ -752,7 +752,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
|
TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
|
||||||
|
@ -770,7 +770,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
|
TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
|
||||||
|
@ -788,7 +788,7 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
|
||||||
// z->printShapeInfo("Flatten1 shape");
|
// z->printShapeInfo("Flatten1 shape");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
||||||
|
@ -806,7 +806,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
||||||
|
@ -845,7 +845,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_1) {
|
||||||
ASSERT_TRUE(sub1.equalsTo(z1));
|
ASSERT_TRUE(sub1.equalsTo(z1));
|
||||||
ASSERT_TRUE(sub2.equalsTo(z2));
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// special test for TF mode, when axis goes first
|
// special test for TF mode, when axis goes first
|
||||||
|
@ -888,7 +888,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_2) {
|
||||||
ASSERT_TRUE(sub2.equalsTo(z2));
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||||||
ASSERT_TRUE(sub3.equalsTo(z3));
|
ASSERT_TRUE(sub3.equalsTo(z3));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// special test for TF mode, when axis goes first
|
// special test for TF mode, when axis goes first
|
||||||
|
@ -923,10 +923,89 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
|
||||||
ASSERT_TRUE(sub0.equalsTo(z0));
|
ASSERT_TRUE(sub0.equalsTo(z0));
|
||||||
ASSERT_TRUE(sub1.equalsTo(z1));
|
ASSERT_TRUE(sub1.equalsTo(z1));
|
||||||
ASSERT_TRUE(sub2.equalsTo(z2));
|
ASSERT_TRUE(sub2.equalsTo(z2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, split_test4) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
|
||||||
|
auto axis = NDArrayFactory::create<double>(-1);
|
||||||
|
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
|
||||||
|
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
|
||||||
|
|
||||||
|
sd::ops::split op;
|
||||||
|
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||||
|
|
||||||
|
auto out1 = results.at(0);
|
||||||
|
auto out2 = results.at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, split_test5) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
|
||||||
|
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
|
||||||
|
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
|
||||||
|
|
||||||
|
sd::ops::split op;
|
||||||
|
auto results = op.evaluate({&input}, {}, {2,-1},{});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||||
|
|
||||||
|
auto out1 = results.at(0);
|
||||||
|
auto out2 = results.at(1);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||||
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
|
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||||
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, split_test6) {
|
||||||
|
|
||||||
|
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
|
||||||
|
std::vector<Nd4jLong> expShape = {0,1};
|
||||||
|
|
||||||
|
const int numSplits = 4;
|
||||||
|
const int axis = 1;
|
||||||
|
|
||||||
|
sd::ops::split op;
|
||||||
|
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||||
|
|
||||||
|
for (int i = 0; i < numSplits; ++i)
|
||||||
|
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests4, split_test7) {
|
||||||
|
|
||||||
|
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
|
||||||
|
std::vector<Nd4jLong> expShape = {0,4};
|
||||||
|
|
||||||
|
const int numSplits = 4;
|
||||||
|
const int axis = 0;
|
||||||
|
|
||||||
|
sd::ops::split op;
|
||||||
|
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||||
|
|
||||||
|
for (int i = 0; i < numSplits; ++i)
|
||||||
|
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
|
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
||||||
|
@ -940,7 +1019,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
|
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
|
||||||
|
@ -957,7 +1036,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -974,7 +1053,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
|
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
|
||||||
|
@ -990,7 +1069,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
|
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
|
||||||
|
@ -1006,7 +1085,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1023,7 +1102,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1040,7 +1119,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
|
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
|
||||||
|
@ -1055,7 +1134,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1073,7 +1152,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1091,7 +1170,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1109,7 +1188,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
||||||
|
@ -1129,7 +1208,7 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
|
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
|
||||||
|
@ -1146,7 +1225,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
||||||
|
@ -1165,7 +1244,7 @@ TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
|
||||||
|
@ -1184,7 +1263,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
||||||
|
@ -1208,7 +1287,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||||||
|
@ -1229,7 +1308,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(z->isEmpty());
|
ASSERT_TRUE(z->isEmpty());
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {1,3}, {1, 2, 3});
|
auto x = NDArrayFactory::create<double>('c', {1,3}, {1, 2, 3});
|
||||||
|
@ -1248,7 +1327,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(z->lengthOf() == 1);
|
ASSERT_TRUE(z->lengthOf() == 1);
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1272,7 +1351,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1293,7 +1372,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test2) {
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1313,7 +1392,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test3) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
\
|
\
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1333,7 +1412,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test4) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1353,7 +1432,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test5) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1373,7 +1452,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test6) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1390,7 +1469,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test7) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1419,7 +1498,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test1) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1446,7 +1525,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test2) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1473,7 +1552,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test3) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1500,7 +1579,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test4) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1527,7 +1606,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test5) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1554,7 +1633,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test6) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1581,7 +1660,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test7) {
|
||||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1598,7 +1677,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test8) {
|
||||||
ASSERT_TRUE(exp0.isSameShape(out0));
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||||||
ASSERT_TRUE(exp0.equalsTo(out0));
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1615,7 +1694,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test9) {
|
||||||
ASSERT_TRUE(exp0.isSameShape(out0));
|
ASSERT_TRUE(exp0.isSameShape(out0));
|
||||||
ASSERT_TRUE(exp0.equalsTo(out0));
|
ASSERT_TRUE(exp0.equalsTo(out0));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1644,7 +1723,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1664,7 +1743,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1718,7 +1797,7 @@ TEST_F(DeclarableOpsTests4, lstm_test1) {
|
||||||
ASSERT_TRUE(expClast.isSameShape(&cLast));
|
ASSERT_TRUE(expClast.isSameShape(&cLast));
|
||||||
ASSERT_TRUE(expClast.equalsTo(&cLast));
|
ASSERT_TRUE(expClast.equalsTo(&cLast));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
@ -1737,7 +1816,7 @@ TEST_F(DeclarableOpsTests4, relu6_test1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1759,7 +1838,7 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1788,7 +1867,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
|
||||||
// exp.printIndexedBuffer("LRN exp");
|
// exp.printIndexedBuffer("LRN exp");
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1816,7 +1895,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) {
|
||||||
// exp.printIndexedBuffer("LRN exp");
|
// exp.printIndexedBuffer("LRN exp");
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1855,7 +1934,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) {
|
||||||
// exp.printIndexedBuffer("LRN exp");
|
// exp.printIndexedBuffer("LRN exp");
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1894,7 +1973,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) {
|
||||||
// exp.printIndexedBuffer("LRN exp");
|
// exp.printIndexedBuffer("LRN exp");
|
||||||
ASSERT_TRUE(exp.equalsTo(out));
|
ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1939,7 +2018,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) {
|
||||||
// exp.printIndexedBuffer("LRN exp");
|
// exp.printIndexedBuffer("LRN exp");
|
||||||
// ASSERT_TRUE(exp.equalsTo(out));
|
// ASSERT_TRUE(exp.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1962,7 +2041,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1983,7 +2062,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2004,7 +2083,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2025,7 +2104,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2044,7 +2123,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2065,7 +2144,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2086,7 +2165,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2105,7 +2184,7 @@ TEST_F(DeclarableOpsTests4, triu_test1) {
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2123,7 +2202,7 @@ TEST_F(DeclarableOpsTests4, triu_test2) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2141,7 +2220,7 @@ TEST_F(DeclarableOpsTests4, triu_test3) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2159,7 +2238,7 @@ TEST_F(DeclarableOpsTests4, triu_test4) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2177,7 +2256,7 @@ TEST_F(DeclarableOpsTests4, triu_test5) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2195,7 +2274,7 @@ TEST_F(DeclarableOpsTests4, triu_test6) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2213,7 +2292,7 @@ TEST_F(DeclarableOpsTests4, triu_test7) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2231,7 +2310,7 @@ TEST_F(DeclarableOpsTests4, triu_test8) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2249,7 +2328,7 @@ TEST_F(DeclarableOpsTests4, triu_test9) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2267,7 +2346,7 @@ TEST_F(DeclarableOpsTests4, triu_test10) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2285,7 +2364,7 @@ TEST_F(DeclarableOpsTests4, triu_test11) {
|
||||||
ASSERT_TRUE(expected.isSameShape(output));
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
ASSERT_TRUE(expected.equalsTo(output));
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2307,7 +2386,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) {
|
||||||
ASSERT_TRUE(expected.isSameShape(gradI));
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||||||
ASSERT_TRUE(expected.equalsTo(gradI));
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2328,7 +2407,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test2) {
|
||||||
ASSERT_TRUE(expected.isSameShape(gradI));
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||||||
ASSERT_TRUE(expected.equalsTo(gradI));
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2349,7 +2428,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test3) {
|
||||||
ASSERT_TRUE(expected.isSameShape(gradI));
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||||||
ASSERT_TRUE(expected.equalsTo(gradI));
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2370,6 +2449,6 @@ TEST_F(DeclarableOpsTests4, triu_bp_test4) {
|
||||||
ASSERT_TRUE(expected.isSameShape(gradI));
|
ASSERT_TRUE(expected.isSameShape(gradI));
|
||||||
ASSERT_TRUE(expected.equalsTo(gradI));
|
ASSERT_TRUE(expected.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -57,14 +57,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
|
||||||
// output->printIndexedBuffer();
|
// output->printIndexedBuffer();
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
output = result.at(0);
|
output = result.at(0);
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,14 +86,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
|
||||||
// output->printIndexedBuffer();
|
// output->printIndexedBuffer();
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
result = op.evaluate({&x, &gradO1}, {1,0}, {1});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
output = result.at(0);
|
output = result.at(0);
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
|
@ -255,7 +255,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -305,7 +305,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -345,7 +345,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -365,7 +365,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -403,7 +403,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -471,7 +471,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -496,7 +496,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests9, concat_test14) {
|
TEST_F(DeclarableOpsTests9, concat_test14) {
|
||||||
|
@ -548,7 +548,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
|
||||||
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests9, concat_test15) {
|
TEST_F(DeclarableOpsTests9, concat_test15) {
|
||||||
|
@ -565,7 +565,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -582,8 +582,6 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -611,8 +609,6 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
|
||||||
auto mean = tad.meanNumber().e<double>(0);
|
auto mean = tad.meanNumber().e<double>(0);
|
||||||
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -693,7 +689,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
|
||||||
ASSERT_NEAR((double) e+1, mean, 1e-5);
|
ASSERT_NEAR((double) e+1, mean, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -775,7 +771,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -803,6 +799,25 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, concat_test27) {
|
||||||
|
|
||||||
|
auto x1 = NDArrayFactory::create<double>('c', {0,1});
|
||||||
|
auto x2 = NDArrayFactory::create<double>('c', {0,1});
|
||||||
|
auto x3 = NDArrayFactory::create<double>('c', {0,1});
|
||||||
|
auto x4 = NDArrayFactory::create<double>('c', {0,1});
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> expShape = {0, 4};
|
||||||
|
|
||||||
|
sd::ops::concat op;
|
||||||
|
auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(z->isSameShape(expShape));
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||||
|
|
||||||
|
@ -820,7 +835,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -839,7 +854,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -859,7 +874,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -879,7 +894,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -899,7 +914,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -919,7 +934,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -940,7 +955,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) {
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -958,7 +973,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
|
||||||
ASSERT_TRUE(expOut.isSameShape(out));
|
ASSERT_TRUE(expOut.isSameShape(out));
|
||||||
ASSERT_TRUE(expOut.equalsTo(out));
|
ASSERT_TRUE(expOut.equalsTo(out));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -975,7 +990,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
|
||||||
//ress.at(0)->printIndexedBuffer("Result is ");
|
//ress.at(0)->printIndexedBuffer("Result is ");
|
||||||
//x.printIndexedBuffer("Input is");
|
//x.printIndexedBuffer("Input is");
|
||||||
ASSERT_FALSE(ress.at(0)->equalsTo(errs));
|
ASSERT_FALSE(ress.at(0)->equalsTo(errs));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1006,8 +1021,8 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
|
||||||
//res->printIndexedBuffer("FF dropout");
|
//res->printIndexedBuffer("FF dropout");
|
||||||
//res2->printIndexedBuffer("BP dropout");
|
//res2->printIndexedBuffer("BP dropout");
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
||||||
|
@ -1080,7 +1095,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
||||||
*/
|
*/
|
||||||
// ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
|
// ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
|
TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
|
||||||
|
@ -1117,7 +1132,7 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
|
||||||
ASSERT_NEAR(countZero.e<float>(0), 50.f, 10.f);
|
ASSERT_NEAR(countZero.e<float>(0), 50.f, 10.f);
|
||||||
// ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
|
// ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
|
||||||
ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0)));
|
ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0)));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1143,7 +1158,7 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
|
||||||
//res->printIndexedBuffer("Result1AlphaBP1");
|
//res->printIndexedBuffer("Result1AlphaBP1");
|
||||||
//res2->printIndexedBuffer("Result1AlphaBP2");
|
//res2->printIndexedBuffer("Result1AlphaBP2");
|
||||||
ASSERT_TRUE(res2->equalsTo(res));
|
ASSERT_TRUE(res2->equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests9, test_range_int_1) {
|
TEST_F(DeclarableOpsTests9, test_range_int_1) {
|
||||||
|
@ -1227,7 +1242,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
|
||||||
ASSERT_TRUE(result.at(i)->isSameShape(z[i]));
|
ASSERT_TRUE(result.at(i)->isSameShape(z[i]));
|
||||||
ASSERT_TRUE(result.at(i)->equalsTo(z[i]));
|
ASSERT_TRUE(result.at(i)->equalsTo(z[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1263,7 +1278,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
|
||||||
ASSERT_TRUE(expect.isSameShape(outFF));
|
ASSERT_TRUE(expect.isSameShape(outFF));
|
||||||
ASSERT_TRUE(expect.equalsTo(outFF));
|
ASSERT_TRUE(expect.equalsTo(outFF));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1355,7 +1370,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
auto z = result.at(0);
|
auto z = result.at(0);
|
||||||
ASSERT_TRUE(expFF.equalsTo(z));
|
ASSERT_TRUE(expFF.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
//************************************//
|
//************************************//
|
||||||
exclusive = 1; reverse = 0;
|
exclusive = 1; reverse = 0;
|
||||||
|
@ -1364,7 +1379,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
z = result.at(0);
|
z = result.at(0);
|
||||||
ASSERT_TRUE(expTF.equalsTo(z));
|
ASSERT_TRUE(expTF.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
//************************************//
|
//************************************//
|
||||||
exclusive = 0; reverse = 1;
|
exclusive = 0; reverse = 1;
|
||||||
|
@ -1373,7 +1388,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
z = result.at(0);
|
z = result.at(0);
|
||||||
ASSERT_TRUE(expFT.equalsTo(z));
|
ASSERT_TRUE(expFT.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
//************************************//
|
//************************************//
|
||||||
exclusive = 1; reverse = 1;
|
exclusive = 1; reverse = 1;
|
||||||
|
@ -1382,7 +1397,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
|
||||||
ASSERT_EQ(Status::OK(), result.status());
|
ASSERT_EQ(Status::OK(), result.status());
|
||||||
z = result.at(0);
|
z = result.at(0);
|
||||||
ASSERT_TRUE(expTT.equalsTo(z));
|
ASSERT_TRUE(expTT.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1416,7 +1431,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1584,7 +1599,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1602,7 +1617,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1620,7 +1635,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1638,7 +1653,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1656,7 +1671,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1674,7 +1689,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1693,7 +1708,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1711,7 +1726,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1729,7 +1744,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1747,7 +1762,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1772,7 +1787,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1796,7 +1811,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1820,7 +1835,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1845,7 +1860,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1864,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1883,7 +1898,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
|
||||||
// output->printIndexedBuffer("Packed to uint8");
|
// output->printIndexedBuffer("Packed to uint8");
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1902,7 +1917,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2017,7 +2032,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2037,7 +2052,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2057,7 +2072,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2076,7 +2091,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2094,7 +2109,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2403,7 +2418,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||||
dLdu.assign(dLdh * dhdu);
|
dLdu.assign(dLdh * dhdu);
|
||||||
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
|
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
|
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
|
||||||
|
@ -2430,7 +2445,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
||||||
auto res = result.at(0);
|
auto res = result.at(0);
|
||||||
// res->printIndexedBuffer("Output for Cholesky1");
|
// res->printIndexedBuffer("Output for Cholesky1");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2446,7 +2461,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
|
||||||
auto res = result.at(0);
|
auto res = result.at(0);
|
||||||
// res->printIndexedBuffer("Output for Cholesky 2");
|
// res->printIndexedBuffer("Output for Cholesky 2");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2462,7 +2477,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
|
||||||
auto res = result.at(0);
|
auto res = result.at(0);
|
||||||
// res->printIndexedBuffer("Output for Cholesky 3");
|
// res->printIndexedBuffer("Output for Cholesky 3");
|
||||||
ASSERT_TRUE(exp.equalsTo(res, 1e-4));
|
ASSERT_TRUE(exp.equalsTo(res, 1e-4));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -3772,6 +3772,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
|
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
|
||||||
ret.setData(dup(order).data());
|
ret.setData(dup(order).data());
|
||||||
return ret;
|
return ret;
|
||||||
|
} else if (this.isEmpty()) {
|
||||||
|
return Nd4j.create(this.dataType(), shape);
|
||||||
} else {
|
} else {
|
||||||
INDArray ret = this.dup(order);
|
INDArray ret = this.dup(order);
|
||||||
return Nd4j.create(ret.data(), shape);
|
return Nd4j.create(ret.data(), shape);
|
||||||
|
|
|
@ -4997,8 +4997,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// This class is suited for execution results representation.
|
// This class is suited for execution results representation.
|
||||||
//
|
//
|
||||||
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
@ -5011,7 +5011,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
// #include <graph/generated/result_generated.h>
|
// #include <graph/generated/result_generated.h>
|
||||||
// #include <system/pointercast.h>
|
// #include <system/pointercast.h>
|
||||||
// #include <system/dll.h> // forward declaration of template class NDArray
|
// #include <system/dll.h> // forward declaration of template class NDArray
|
||||||
|
|
||||||
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
|
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
@ -6877,6 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||||
|
|
||||||
@Namespace("shape") public static native void traceNew(int id);
|
@Namespace("shape") public static native void traceNew(int id);
|
||||||
|
|
||||||
|
@ -7814,14 +7817,20 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
||||||
|
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||||
|
@ -7833,15 +7842,15 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* increment n-dimensional array by one iteration by changing coord appropriately
|
* increment n-dimensional array by one iteration by changing coord appropriately
|
||||||
|
@ -7931,32 +7940,32 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||||
|
|
||||||
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||||
|
|
||||||
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
|
@ -8052,9 +8061,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -8930,6 +8937,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||||
|
|
||||||
|
@ -9126,6 +9137,23 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
||||||
|
|
||||||
|
// Nd4jLong result = 9223372036854775807LL;
|
||||||
|
|
||||||
|
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
|
||||||
|
|
||||||
|
// const auto currentStride = shape::stride(inShapeInfo)[i];
|
||||||
|
|
||||||
|
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
|
||||||
|
// continue;
|
||||||
|
|
||||||
|
// if(result > currentStride)
|
||||||
|
// result = currentStride;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return result == 9223372036854775807LL ? 1 : result;
|
||||||
|
// }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -9736,18 +9764,17 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
|
|
||||||
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||||
|
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
|
||||||
|
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||||
|
@ -9762,8 +9789,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||||
|
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||||
|
|
||||||
|
|
||||||
// There methods provide various validation options
|
// There methods provide various validation options
|
||||||
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
||||||
|
@ -9835,9 +9863,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
|
|
||||||
public native @Cast("Nd4jStatus") int execute(Context block);
|
public native @Cast("Nd4jStatus") int execute(Context block);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||||
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
|
|
|
@ -5000,8 +5000,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// This class is suited for execution results representation.
|
// This class is suited for execution results representation.
|
||||||
//
|
//
|
||||||
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
// PLESE NOTE: It will delete all stored NDArrays upon destructor call
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
@ -5014,7 +5014,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
// #include <graph/generated/result_generated.h>
|
// #include <graph/generated/result_generated.h>
|
||||||
// #include <system/pointercast.h>
|
// #include <system/pointercast.h>
|
||||||
// #include <system/dll.h> // forward declaration of template class NDArray
|
// #include <system/dll.h> // forward declaration of template class NDArray
|
||||||
|
|
||||||
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
|
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
@ -6880,6 +6880,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||||
|
|
||||||
@Namespace("shape") public static native void traceNew(int id);
|
@Namespace("shape") public static native void traceNew(int id);
|
||||||
|
|
||||||
|
@ -7817,14 +7820,20 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
||||||
|
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||||
|
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||||
|
@ -7836,15 +7845,15 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* increment n-dimensional array by one iteration by changing coord appropriately
|
* increment n-dimensional array by one iteration by changing coord appropriately
|
||||||
|
@ -7934,32 +7943,32 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||||
|
|
||||||
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||||
|
|
||||||
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
|
@ -8055,9 +8064,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -8933,6 +8940,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||||
|
|
||||||
|
@ -9129,6 +9140,23 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
||||||
|
|
||||||
|
// Nd4jLong result = 9223372036854775807LL;
|
||||||
|
|
||||||
|
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
|
||||||
|
|
||||||
|
// const auto currentStride = shape::stride(inShapeInfo)[i];
|
||||||
|
|
||||||
|
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
|
||||||
|
// continue;
|
||||||
|
|
||||||
|
// if(result > currentStride)
|
||||||
|
// result = currentStride;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return result == 9223372036854775807LL ? 1 : result;
|
||||||
|
// }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -11949,18 +11977,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
|
|
||||||
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||||
|
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
|
||||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
|
||||||
|
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||||
|
@ -11975,8 +12002,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||||
|
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||||
|
|
||||||
|
|
||||||
// There methods provide various validation options
|
// There methods provide various validation options
|
||||||
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
||||||
|
@ -12048,9 +12076,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
|
|
||||||
|
|
||||||
public native @Cast("Nd4jStatus") int execute(Context block);
|
public native @Cast("Nd4jStatus") int execute(Context block);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||||
|
|
||||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2050,7 +2050,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
print(out[0]);
|
print(out[0]);
|
||||||
*/
|
*/
|
||||||
|
|
||||||
INDArray emptyIn = Nd4j.empty(DataType.FLOAT);
|
INDArray emptyIn = Nd4j.empty(DataType.FLOAT).reshape(0, 4);
|
||||||
INDArray axis = Nd4j.scalar(1);
|
INDArray axis = Nd4j.scalar(1);
|
||||||
|
|
||||||
DynamicCustomOp op = DynamicCustomOp.builder("split")
|
DynamicCustomOp op = DynamicCustomOp.builder("split")
|
||||||
|
@ -2061,9 +2061,10 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
List<LongShapeDescriptor> l = op.calculateOutputShape();
|
List<LongShapeDescriptor> l = op.calculateOutputShape();
|
||||||
assertEquals(4, l.size());
|
assertEquals(4, l.size());
|
||||||
for( int i=0; i<4; i++ ){
|
for( int i=0; i<4; i++ ){
|
||||||
assertArrayEquals(new long[0], l.get(i).getShape());
|
val desc = l.get(i);
|
||||||
assertTrue(l.get(i).isEmpty());
|
assertArrayEquals(new long[]{0, 1}, desc.getShape());
|
||||||
op.addOutputArgument(Nd4j.empty(DataType.FLOAT));
|
assertTrue(desc.isEmpty());
|
||||||
|
op.addOutputArgument(Nd4j.empty(DataType.FLOAT).reshape(desc.getShape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
|
|
Loading…
Reference in New Issue