correct output empty shapes deducing in split op (#311)

* - correct output empty shapes deducing in split op

Signed-off-by: Yurii <iuriish@yahoo.com>

* java test fixed

Signed-off-by: raver119 <raver119@gmail.com>

* - split broadcast::exec function on individual functions corresponding to switch arg

Signed-off-by: Yurii <iuriish@yahoo.com>

* - split broadcast::exec _int and _bool function on individual functions corresponding to switch arg

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2020-03-12 17:25:54 +02:00 committed by GitHub
parent 41bde8f885
commit e42b4e96c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1306 additions and 983 deletions

View File

@ -572,6 +572,272 @@ template <typename X, typename Y, typename Z>
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
} }
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
template<typename OpType> template<typename OpType>
@ -582,220 +848,26 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
Z* z = reinterpret_cast<Z*>(vz); Z* z = reinterpret_cast<Z*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) { switch (rank) {
case 1: { case 1:
execRank1<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{ break;
case 2:
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { execRank2<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
for (auto i0 = start; i0 < stop; ++i0) break;
z[i0] = OpType::op(x[i0], *y); case 3:
} execRank3<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { break;
for (auto i0 = start; i0 < stop; ++i0) case 4:
z[i0] = OpType::op(*x, y[i0]); execRank4<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
} break;
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { case 5:
for (auto i0 = start; i0 < stop; ++i0) execRank5<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
z[i0] = OpType::op(x[i0], y[i0]); break;
} default:
else { execDefault<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
} }
} }

View File

@ -453,6 +453,271 @@ namespace broadcast {
} }
} }
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y, extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0], extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0], extraParams);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X, typename Z> template <typename X, typename Z>
template<typename OpType> template<typename OpType>
@ -468,220 +733,26 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
X* extraParams = reinterpret_cast<X*>(vextraParams); X* extraParams = reinterpret_cast<X*>(vextraParams);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) { switch (rank) {
case 1: { case 1:
execRank1<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
auto func = PRAGMA_THREADS_FOR{ break;
case 2:
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { execRank2<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
for (auto i0 = start; i0 < stop; ++i0) break;
z[i0] = OpType::op(x[i0], *y, extraParams); case 3:
} execRank3<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { break;
for (auto i0 = start; i0 < stop; ++i0) case 4:
z[i0] = OpType::op(*x, y[i0], extraParams); execRank4<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
} break;
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { case 5:
for (auto i0 = start; i0 < stop; ++i0) execRank5<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
z[i0] = OpType::op(x[i0], y[i0], extraParams); break;
} default:
else { execDefault<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
} }
} }

View File

@ -439,6 +439,271 @@ namespace functions {
} }
} }
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template <typename X> template <typename X>
@ -452,220 +717,26 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
X* z = reinterpret_cast<X*>(vz); X* z = reinterpret_cast<X*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) { switch (rank) {
case 1: { case 1:
execRank1<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{ break;
case 2:
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { execRank2<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
for (auto i0 = start; i0 < stop; ++i0) break;
z[i0] = OpType::op(x[i0], *y); case 3:
} execRank3<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { break;
for (auto i0 = start; i0 < stop; ++i0) case 4:
z[i0] = OpType::op(*x, y[i0]); execRank4<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
} break;
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { case 5:
for (auto i0 = start; i0 < stop; ++i0) execRank5<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
z[i0] = OpType::op(x[i0], y[i0]); break;
} default:
else { execDefault<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
} }
} }

View File

@ -49,7 +49,7 @@ namespace ops {
input = a; input = a;
} }
} }
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
if(input->isEmpty()){ if(input->isEmpty()){
for( int i=0; i< num_splits; i++ ){ for( int i=0; i< num_splits; i++ ){
@ -112,17 +112,17 @@ namespace ops {
inputVar = 0; inputVar = 0;
} }
} }
auto shapes = SHAPELIST(); auto shapes = SHAPELIST();
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
if(INPUT_VARIABLE(inputVar)->isEmpty()){ // if(INPUT_VARIABLE(inputVar)->isEmpty()){
for (int e = 0; e < num_splits; e++) { // for (int e = 0; e < num_splits; e++) {
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType); // auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
shapes->push_back(empty); // shapes->push_back(empty);
} // }
return shapes; // return shapes;
} // }
if (block.numI() == 2) if (block.numI() == 2)
axis = INT_ARG(1); axis = INT_ARG(1);
@ -135,9 +135,9 @@ namespace ops {
for (int e = 0; e < shape::rank(input); e++) for (int e = 0; e < shape::rank(input); e++)
if (e == axis) if (e == axis)
shape[e] = shape::sizeAt(input, e) / num_splits; shape[e] = shape::sizeAt(input, e) / num_splits;
else else
shape[e] = shape::sizeAt(input, e); shape[e] = shape::sizeAt(input, e);
for (int e = 0; e < num_splits; e++) { for (int e = 0; e < num_splits; e++) {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape);
shapes->push_back(newShape); shapes->push_back(newShape);

View File

@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
ASSERT_EQ(e, *result.at(0)); ASSERT_EQ(e, *result.at(0));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -170,7 +170,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) {
ASSERT_TRUE(exp.equalsTo(res1)); ASSERT_TRUE(exp.equalsTo(res1));
ASSERT_TRUE(expIdx.equalsTo(res2)); ASSERT_TRUE(expIdx.equalsTo(res2));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -187,7 +187,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) {
ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.isSameShape(resA));
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) {
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
ASSERT_TRUE(exp2.equalsTo(res2)); ASSERT_TRUE(exp2.equalsTo(res2));
ASSERT_TRUE(exp3.equalsTo(res3)); ASSERT_TRUE(exp3.equalsTo(res3));
//ASSERT_TRUE(expIdx.equalsTo(res.at(1))); //ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -245,7 +245,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
ASSERT_TRUE(exp1.equalsTo(res.at(0))); ASSERT_TRUE(exp1.equalsTo(res.at(0)));
ASSERT_TRUE(exp2.equalsTo(res.at(1))); ASSERT_TRUE(exp2.equalsTo(res.at(1)));
//ASSERT_TRUE(expIdx.equalsTo(res.at(1))); //ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -263,7 +263,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) {
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -281,7 +281,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) {
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) {
//ASSERT_TRUE(exp.equalsTo(resA)); //ASSERT_TRUE(exp.equalsTo(resA));
//ASSERT_TRUE(exp.isSameShape(resA)); //ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) {
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -337,7 +337,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) {
//ASSERT_TRUE(exp.equalsTo(resA)); //ASSERT_TRUE(exp.equalsTo(resA));
//ASSERT_TRUE(exp.isSameShape(resA)); //ASSERT_TRUE(exp.isSameShape(resA));
// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); // ASSERT_TRUE(expIdx.equalsTo(res.at(1)));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -354,7 +354,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) {
auto resA = res.at(0); auto resA = res.at(0);
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -371,7 +371,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
auto resA = res.at(0); auto resA = res.at(0);
ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.equalsTo(resA));
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
@ -567,7 +567,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -764,50 +764,6 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, split_test4) {
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
auto axis = NDArrayFactory::create<double>(-1);
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
sd::ops::split op;
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, split_test5) {
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {2,-1},{});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) { TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
@ -1464,7 +1420,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
//expected.printIndexedBuffer("Expect for 10x10"); //expected.printIndexedBuffer("Expect for 10x10");
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
@ -1511,7 +1467,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
// expected.printBuffer("Expect for 4x5"); // expected.printBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
@ -1569,7 +1525,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
// expected.printShapeInfo("Expect shape"); // expected.printShapeInfo("Expect shape");
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
@ -1724,7 +1680,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) {
// expected.printShapeInfo("Expect shape"); // expected.printShapeInfo("Expect shape");
ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -2053,7 +2009,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
auto res = result.at(0); auto res = result.at(0);
ASSERT_TRUE(expect.equalsTo(res)); ASSERT_TRUE(expect.equalsTo(res));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
@ -2276,7 +2232,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) {
@ -2315,7 +2271,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.isSameShapeStrict(*result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////

View File

@ -73,7 +73,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -94,7 +94,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -115,7 +115,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -135,7 +135,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -156,7 +156,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -177,7 +177,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -198,7 +198,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -218,7 +218,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -242,7 +242,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -273,7 +273,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -339,7 +339,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_11) {
ASSERT_EQ(m, *z); ASSERT_EQ(m, *z);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -373,7 +373,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests4, biasadd_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, biasadd_2) { TEST_F(DeclarableOpsTests4, biasadd_2) {
@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests4, biasadd_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, biasadd_3) { TEST_F(DeclarableOpsTests4, biasadd_3) {
@ -581,7 +581,7 @@ TEST_F(DeclarableOpsTests4, biasadd_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -609,7 +609,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_1) {
ASSERT_TRUE(gradB->isSameShape(expGradB)); ASSERT_TRUE(gradB->isSameShape(expGradB));
ASSERT_TRUE(gradB->equalsTo(expGradB)); ASSERT_TRUE(gradB->equalsTo(expGradB));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_2) {
ASSERT_TRUE(gradB->isSameShape(expGradB)); ASSERT_TRUE(gradB->isSameShape(expGradB));
ASSERT_TRUE(gradB->equalsTo(expGradB)); ASSERT_TRUE(gradB->equalsTo(expGradB));
} }
TEST_F(DeclarableOpsTests4, biasadd_4) { TEST_F(DeclarableOpsTests4, biasadd_4) {
@ -672,7 +672,7 @@ TEST_F(DeclarableOpsTests4, Test_Fill_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
@ -694,7 +694,7 @@ TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
// ASSERT_TRUE(exp.isSameShape(z)); // ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
@ -713,7 +713,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
// z->printShapeInfo("Flatten1 shape"); // z->printShapeInfo("Flatten1 shape");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
@ -734,7 +734,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
@ -752,7 +752,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
@ -770,7 +770,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
@ -788,7 +788,7 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
// z->printShapeInfo("Flatten1 shape"); // z->printShapeInfo("Flatten1 shape");
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
@ -806,7 +806,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Split_1) { TEST_F(DeclarableOpsTests4, Test_Split_1) {
@ -845,7 +845,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_1) {
ASSERT_TRUE(sub1.equalsTo(z1)); ASSERT_TRUE(sub1.equalsTo(z1));
ASSERT_TRUE(sub2.equalsTo(z2)); ASSERT_TRUE(sub2.equalsTo(z2));
} }
// special test for TF mode, when axis goes first // special test for TF mode, when axis goes first
@ -888,7 +888,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_2) {
ASSERT_TRUE(sub2.equalsTo(z2)); ASSERT_TRUE(sub2.equalsTo(z2));
ASSERT_TRUE(sub3.equalsTo(z3)); ASSERT_TRUE(sub3.equalsTo(z3));
} }
// special test for TF mode, when axis goes first // special test for TF mode, when axis goes first
@ -923,10 +923,89 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
ASSERT_TRUE(sub0.equalsTo(z0)); ASSERT_TRUE(sub0.equalsTo(z0));
ASSERT_TRUE(sub1.equalsTo(z1)); ASSERT_TRUE(sub1.equalsTo(z1));
ASSERT_TRUE(sub2.equalsTo(z2)); ASSERT_TRUE(sub2.equalsTo(z2));
} }
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test4) {
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
auto axis = NDArrayFactory::create<double>(-1);
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
sd::ops::split op;
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test5) {
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {2,-1},{});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test6) {
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
std::vector<Nd4jLong> expShape = {0,1};
const int numSplits = 4;
const int axis = 1;
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
for (int i = 0; i < numSplits; ++i)
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test7) {
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
std::vector<Nd4jLong> expShape = {0,4};
const int numSplits = 4;
const int axis = 0;
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
for (int i = 0; i < numSplits; ++i)
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
}
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
@ -940,7 +1019,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
@ -957,7 +1036,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -974,7 +1053,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
@ -990,7 +1069,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
@ -1006,7 +1085,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1023,7 +1102,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1040,7 +1119,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
@ -1055,7 +1134,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
@ -1073,7 +1152,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1091,7 +1170,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1109,7 +1188,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Add_119) { TEST_F(DeclarableOpsTests4, Test_Add_119) {
@ -1129,7 +1208,7 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) { TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
@ -1146,7 +1225,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
@ -1165,7 +1244,7 @@ TEST_F(DeclarableOpsTests4, Test_TileToShape_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
@ -1184,7 +1263,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
@ -1208,7 +1287,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
@ -1229,7 +1308,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(z->isEmpty()); ASSERT_TRUE(z->isEmpty());
} }
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
auto x = NDArrayFactory::create<double>('c', {1,3}, {1, 2, 3}); auto x = NDArrayFactory::create<double>('c', {1,3}, {1, 2, 3});
@ -1248,7 +1327,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(z->lengthOf() == 1); ASSERT_TRUE(z->lengthOf() == 1);
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1272,7 +1351,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1293,7 +1372,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test2) {
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1313,7 +1392,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test3) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
\ \
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1333,7 +1412,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test4) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1353,7 +1432,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test5) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1373,7 +1452,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test6) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1390,7 +1469,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test7) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1419,7 +1498,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test1) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1446,7 +1525,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test2) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1473,7 +1552,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test3) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1500,7 +1579,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test4) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1527,7 +1606,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test5) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1554,7 +1633,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test6) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1581,7 +1660,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test7) {
ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp2.equalsTo(out2)); ASSERT_TRUE(exp2.equalsTo(out2));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1598,7 +1677,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test8) {
ASSERT_TRUE(exp0.isSameShape(out0)); ASSERT_TRUE(exp0.isSameShape(out0));
ASSERT_TRUE(exp0.equalsTo(out0)); ASSERT_TRUE(exp0.equalsTo(out0));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1615,7 +1694,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test9) {
ASSERT_TRUE(exp0.isSameShape(out0)); ASSERT_TRUE(exp0.isSameShape(out0));
ASSERT_TRUE(exp0.equalsTo(out0)); ASSERT_TRUE(exp0.equalsTo(out0));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1644,7 +1723,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1664,7 +1743,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1718,7 +1797,7 @@ TEST_F(DeclarableOpsTests4, lstm_test1) {
ASSERT_TRUE(expClast.isSameShape(&cLast)); ASSERT_TRUE(expClast.isSameShape(&cLast));
ASSERT_TRUE(expClast.equalsTo(&cLast)); ASSERT_TRUE(expClast.equalsTo(&cLast));
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
@ -1737,7 +1816,7 @@ TEST_F(DeclarableOpsTests4, relu6_test1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -1759,7 +1838,7 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1788,7 +1867,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
// exp.printIndexedBuffer("LRN exp"); // exp.printIndexedBuffer("LRN exp");
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
@ -1816,7 +1895,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) {
// exp.printIndexedBuffer("LRN exp"); // exp.printIndexedBuffer("LRN exp");
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1855,7 +1934,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) {
// exp.printIndexedBuffer("LRN exp"); // exp.printIndexedBuffer("LRN exp");
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1894,7 +1973,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) {
// exp.printIndexedBuffer("LRN exp"); // exp.printIndexedBuffer("LRN exp");
ASSERT_TRUE(exp.equalsTo(out)); ASSERT_TRUE(exp.equalsTo(out));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1939,7 +2018,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) {
// exp.printIndexedBuffer("LRN exp"); // exp.printIndexedBuffer("LRN exp");
// ASSERT_TRUE(exp.equalsTo(out)); // ASSERT_TRUE(exp.equalsTo(out));
} }
@ -1962,7 +2041,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1983,7 +2062,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2004,7 +2083,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2025,7 +2104,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2044,7 +2123,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2065,7 +2144,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2086,7 +2165,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2105,7 +2184,7 @@ TEST_F(DeclarableOpsTests4, triu_test1) {
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2123,7 +2202,7 @@ TEST_F(DeclarableOpsTests4, triu_test2) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2141,7 +2220,7 @@ TEST_F(DeclarableOpsTests4, triu_test3) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2159,7 +2238,7 @@ TEST_F(DeclarableOpsTests4, triu_test4) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2177,7 +2256,7 @@ TEST_F(DeclarableOpsTests4, triu_test5) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2195,7 +2274,7 @@ TEST_F(DeclarableOpsTests4, triu_test6) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2213,7 +2292,7 @@ TEST_F(DeclarableOpsTests4, triu_test7) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2231,7 +2310,7 @@ TEST_F(DeclarableOpsTests4, triu_test8) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2249,7 +2328,7 @@ TEST_F(DeclarableOpsTests4, triu_test9) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2267,7 +2346,7 @@ TEST_F(DeclarableOpsTests4, triu_test10) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2285,7 +2364,7 @@ TEST_F(DeclarableOpsTests4, triu_test11) {
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
} }
@ -2307,7 +2386,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) {
ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.isSameShape(gradI));
ASSERT_TRUE(expected.equalsTo(gradI)); ASSERT_TRUE(expected.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2328,7 +2407,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test2) {
ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.isSameShape(gradI));
ASSERT_TRUE(expected.equalsTo(gradI)); ASSERT_TRUE(expected.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2349,7 +2428,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test3) {
ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.isSameShape(gradI));
ASSERT_TRUE(expected.equalsTo(gradI)); ASSERT_TRUE(expected.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -2370,6 +2449,6 @@ TEST_F(DeclarableOpsTests4, triu_bp_test4) {
ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.isSameShape(gradI));
ASSERT_TRUE(expected.equalsTo(gradI)); ASSERT_TRUE(expected.equalsTo(gradI));
} }

View File

@ -57,14 +57,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
// output->printIndexedBuffer(); // output->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
result = op.evaluate({&x, &gradO1}, {1,0}, {1}); result = op.evaluate({&x, &gradO1}, {1,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
output = result.at(0); output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
@ -86,14 +86,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
// output->printIndexedBuffer(); // output->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
result = op.evaluate({&x, &gradO1}, {1,0}, {1}); result = op.evaluate({&x, &gradO1}, {1,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
output = result.at(0); output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
/* /*
@ -255,7 +255,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -305,7 +305,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -345,7 +345,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -365,7 +365,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -403,7 +403,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -471,7 +471,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -496,7 +496,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
TEST_F(DeclarableOpsTests9, concat_test14) { TEST_F(DeclarableOpsTests9, concat_test14) {
@ -548,7 +548,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
ASSERT_NEAR((e+1)*1., mean, 1e-5); ASSERT_NEAR((e+1)*1., mean, 1e-5);
} }
} }
TEST_F(DeclarableOpsTests9, concat_test15) { TEST_F(DeclarableOpsTests9, concat_test15) {
@ -565,7 +565,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -582,8 +582,6 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -611,8 +609,6 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
auto mean = tad.meanNumber().e<double>(0); auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5); ASSERT_NEAR((e+1)*1., mean, 1e-5);
} }
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -693,7 +689,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
ASSERT_NEAR((double) e+1, mean, 1e-5); ASSERT_NEAR((double) e+1, mean, 1e-5);
} }
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -775,7 +771,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -803,6 +799,25 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test27) {
auto x1 = NDArrayFactory::create<double>('c', {0,1});
auto x2 = NDArrayFactory::create<double>('c', {0,1});
auto x3 = NDArrayFactory::create<double>('c', {0,1});
auto x4 = NDArrayFactory::create<double>('c', {0,1});
std::vector<Nd4jLong> expShape = {0, 4};
sd::ops::concat op;
auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(expShape));
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) { TEST_F(DeclarableOpsTests9, tile_bp_test1) {
@ -820,7 +835,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -839,7 +854,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -859,7 +874,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -879,7 +894,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -899,7 +914,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -919,7 +934,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -940,7 +955,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) {
ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -958,7 +973,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) {
ASSERT_TRUE(expOut.isSameShape(out)); ASSERT_TRUE(expOut.isSameShape(out));
ASSERT_TRUE(expOut.equalsTo(out)); ASSERT_TRUE(expOut.equalsTo(out));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -975,7 +990,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) {
//ress.at(0)->printIndexedBuffer("Result is "); //ress.at(0)->printIndexedBuffer("Result is ");
//x.printIndexedBuffer("Input is"); //x.printIndexedBuffer("Input is");
ASSERT_FALSE(ress.at(0)->equalsTo(errs)); ASSERT_FALSE(ress.at(0)->equalsTo(errs));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1006,8 +1021,8 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) {
//res->printIndexedBuffer("FF dropout"); //res->printIndexedBuffer("FF dropout");
//res2->printIndexedBuffer("BP dropout"); //res2->printIndexedBuffer("BP dropout");
} }
TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
@ -1080,7 +1095,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
*/ */
// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); // ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
} }
TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
@ -1117,7 +1132,7 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) {
ASSERT_NEAR(countZero.e<float>(0), 50.f, 10.f); ASSERT_NEAR(countZero.e<float>(0), 50.f, 10.f);
// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); // ASSERT_TRUE(exp.equalsTo(ressX->at(0)));
ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0))); ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0)));
} }
@ -1143,7 +1158,7 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) {
//res->printIndexedBuffer("Result1AlphaBP1"); //res->printIndexedBuffer("Result1AlphaBP1");
//res2->printIndexedBuffer("Result1AlphaBP2"); //res2->printIndexedBuffer("Result1AlphaBP2");
ASSERT_TRUE(res2->equalsTo(res)); ASSERT_TRUE(res2->equalsTo(res));
} }
TEST_F(DeclarableOpsTests9, test_range_int_1) { TEST_F(DeclarableOpsTests9, test_range_int_1) {
@ -1227,7 +1242,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
ASSERT_TRUE(result.at(i)->isSameShape(z[i])); ASSERT_TRUE(result.at(i)->isSameShape(z[i]));
ASSERT_TRUE(result.at(i)->equalsTo(z[i])); ASSERT_TRUE(result.at(i)->equalsTo(z[i]));
} }
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1263,7 +1278,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
ASSERT_TRUE(expect.isSameShape(outFF)); ASSERT_TRUE(expect.isSameShape(outFF));
ASSERT_TRUE(expect.equalsTo(outFF)); ASSERT_TRUE(expect.equalsTo(outFF));
} }
@ -1355,7 +1370,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0); auto z = result.at(0);
ASSERT_TRUE(expFF.equalsTo(z)); ASSERT_TRUE(expFF.equalsTo(z));
//************************************// //************************************//
exclusive = 1; reverse = 0; exclusive = 1; reverse = 0;
@ -1364,7 +1379,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
z = result.at(0); z = result.at(0);
ASSERT_TRUE(expTF.equalsTo(z)); ASSERT_TRUE(expTF.equalsTo(z));
//************************************// //************************************//
exclusive = 0; reverse = 1; exclusive = 0; reverse = 1;
@ -1373,7 +1388,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
z = result.at(0); z = result.at(0);
ASSERT_TRUE(expFT.equalsTo(z)); ASSERT_TRUE(expFT.equalsTo(z));
//************************************// //************************************//
exclusive = 1; reverse = 1; exclusive = 1; reverse = 1;
@ -1382,7 +1397,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) {
ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(Status::OK(), result.status());
z = result.at(0); z = result.at(0);
ASSERT_TRUE(expTT.equalsTo(z)); ASSERT_TRUE(expTT.equalsTo(z));
} }
@ -1416,7 +1431,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1584,7 +1599,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1602,7 +1617,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1620,7 +1635,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1638,7 +1653,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1656,7 +1671,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1674,7 +1689,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
@ -1693,7 +1708,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1711,7 +1726,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1729,7 +1744,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1747,7 +1762,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1772,7 +1787,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1796,7 +1811,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1820,7 +1835,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1845,7 +1860,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1864,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1883,7 +1898,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) {
// output->printIndexedBuffer("Packed to uint8"); // output->printIndexedBuffer("Packed to uint8");
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -1902,7 +1917,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2017,7 +2032,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2037,7 +2052,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2057,7 +2072,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2076,7 +2091,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2094,7 +2109,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -2403,7 +2418,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
dLdu.assign(dLdh * dhdu); dLdu.assign(dLdh * dhdu);
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose())); dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
@ -2430,7 +2445,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
auto res = result.at(0); auto res = result.at(0);
// res->printIndexedBuffer("Output for Cholesky1"); // res->printIndexedBuffer("Output for Cholesky1");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -2446,7 +2461,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
auto res = result.at(0); auto res = result.at(0);
// res->printIndexedBuffer("Output for Cholesky 2"); // res->printIndexedBuffer("Output for Cholesky 2");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
@ -2462,7 +2477,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
auto res = result.at(0); auto res = result.at(0);
// res->printIndexedBuffer("Output for Cholesky 3"); // res->printIndexedBuffer("Output for Cholesky 3");
ASSERT_TRUE(exp.equalsTo(res, 1e-4)); ASSERT_TRUE(exp.equalsTo(res, 1e-4));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////

View File

@ -3772,6 +3772,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order); INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
ret.setData(dup(order).data()); ret.setData(dup(order).data());
return ret; return ret;
} else if (this.isEmpty()) {
return Nd4j.create(this.dataType(), shape);
} else { } else {
INDArray ret = this.dup(order); INDArray ret = this.dup(order);
return Nd4j.create(ret.data(), shape); return Nd4j.create(ret.data(), shape);

View File

@ -4997,8 +4997,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
******************************************************************************/ ******************************************************************************/
// //
// This class is suited for execution results representation. // This class is suited for execution results representation.
// //
// PLESE NOTE: It will delete all stored NDArrays upon destructor call // PLESE NOTE: It will delete all stored NDArrays upon destructor call
// //
// @author raver119@gmail.com // @author raver119@gmail.com
@ -5011,7 +5011,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <graph/generated/result_generated.h> // #include <graph/generated/result_generated.h>
// #include <system/pointercast.h> // #include <system/pointercast.h>
// #include <system/dll.h> // forward declaration of template class NDArray // #include <system/dll.h> // forward declaration of template class NDArray
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer { @Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -6877,6 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id); @Namespace("shape") public static native void traceNew(int id);
@ -7814,14 +7817,20 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
/** /**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/ */
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
/** /**
* Convert coordinates to the corresponding linear index (sequence number in other words) * Convert coordinates to the corresponding linear index (sequence number in other words)
@ -7833,15 +7842,15 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
/** /**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
/** /**
* increment n-dimensional array by one iteration by changing coord appropriately * increment n-dimensional array by one iteration by changing coord appropriately
@ -7931,32 +7940,32 @@ public static final int PREALLOC_SIZE = 33554432;
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape // rank is equal to size of shape
@ -8052,9 +8061,7 @@ public static final int PREALLOC_SIZE = 33554432;
* get stride over contiguous axis (contiguous axis must have stride = 1) * get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo); // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8930,6 +8937,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9126,6 +9137,23 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
// Nd4jLong result = 9223372036854775807LL;
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
// const auto currentStride = shape::stride(inShapeInfo)[i];
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
// continue;
// if(result > currentStride)
// result = currentStride;
// }
// return result == 9223372036854775807LL ? 1 : result;
// }
@ -9736,18 +9764,17 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
@ -9762,8 +9789,9 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder); public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
// There methods provide various validation options // There methods provide various validation options
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
@ -9835,9 +9863,9 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }

View File

@ -5000,8 +5000,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
******************************************************************************/ ******************************************************************************/
// //
// This class is suited for execution results representation. // This class is suited for execution results representation.
// //
// PLESE NOTE: It will delete all stored NDArrays upon destructor call // PLESE NOTE: It will delete all stored NDArrays upon destructor call
// //
// @author raver119@gmail.com // @author raver119@gmail.com
@ -5014,7 +5014,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
// #include <graph/generated/result_generated.h> // #include <graph/generated/result_generated.h>
// #include <system/pointercast.h> // #include <system/pointercast.h>
// #include <system/dll.h> // forward declaration of template class NDArray // #include <system/dll.h> // forward declaration of template class NDArray
@Namespace("sd") @NoOffset public static class ResultSet extends Pointer { @Namespace("sd") @NoOffset public static class ResultSet extends Pointer {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -6880,6 +6880,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id); @Namespace("shape") public static native void traceNew(int id);
@ -7817,14 +7820,20 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
/** /**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/ */
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
/** /**
* Convert coordinates to the corresponding linear index (sequence number in other words) * Convert coordinates to the corresponding linear index (sequence number in other words)
@ -7836,15 +7845,15 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
/** /**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
/** /**
* increment n-dimensional array by one iteration by changing coord appropriately * increment n-dimensional array by one iteration by changing coord appropriately
@ -7934,32 +7943,32 @@ public static final int PREALLOC_SIZE = 33554432;
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape // rank is equal to size of shape
@ -8055,9 +8064,7 @@ public static final int PREALLOC_SIZE = 33554432;
* get stride over contiguous axis (contiguous axis must have stride = 1) * get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo); // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
@ -8933,6 +8940,10 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9129,6 +9140,23 @@ public static final int PREALLOC_SIZE = 33554432;
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
// Nd4jLong result = 9223372036854775807LL;
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
// const auto currentStride = shape::stride(inShapeInfo)[i];
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
// continue;
// if(result > currentStride)
// result = currentStride;
// }
// return result == 9223372036854775807LL ? 1 : result;
// }
@ -11949,18 +11977,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
@ -11975,8 +12002,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder); public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
// There methods provide various validation options // There methods provide various validation options
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
@ -12048,9 +12076,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(Context block); public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }

View File

@ -2050,7 +2050,7 @@ public class ShapeOpValidation extends BaseOpValidation {
print(out[0]); print(out[0]);
*/ */
INDArray emptyIn = Nd4j.empty(DataType.FLOAT); INDArray emptyIn = Nd4j.empty(DataType.FLOAT).reshape(0, 4);
INDArray axis = Nd4j.scalar(1); INDArray axis = Nd4j.scalar(1);
DynamicCustomOp op = DynamicCustomOp.builder("split") DynamicCustomOp op = DynamicCustomOp.builder("split")
@ -2061,9 +2061,10 @@ public class ShapeOpValidation extends BaseOpValidation {
List<LongShapeDescriptor> l = op.calculateOutputShape(); List<LongShapeDescriptor> l = op.calculateOutputShape();
assertEquals(4, l.size()); assertEquals(4, l.size());
for( int i=0; i<4; i++ ){ for( int i=0; i<4; i++ ){
assertArrayEquals(new long[0], l.get(i).getShape()); val desc = l.get(i);
assertTrue(l.get(i).isEmpty()); assertArrayEquals(new long[]{0, 1}, desc.getShape());
op.addOutputArgument(Nd4j.empty(DataType.FLOAT)); assertTrue(desc.isEmpty());
op.addOutputArgument(Nd4j.empty(DataType.FLOAT).reshape(desc.getShape()));
} }
Nd4j.exec(op); Nd4j.exec(op);