correct output empty shapes deducing in split op (#311)

* - correct output empty shapes deducing in split op

Signed-off-by: Yurii <iuriish@yahoo.com>

* java test fixed

Signed-off-by: raver119 <raver119@gmail.com>

* - split broadcast::exec function on individual functions corresponding to switch arg

Signed-off-by: Yurii <iuriish@yahoo.com>

* - split broadcast::exec _int and _bool function on individual functions corresponding to switch arg

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Yurii Shyrma 2020-03-12 17:25:54 +02:00 committed by GitHub
parent 41bde8f885
commit e42b4e96c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1306 additions and 983 deletions

View File

@ -572,6 +572,272 @@ template <typename X, typename Y, typename Z>
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
@ -582,220 +848,26 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
Z* z = reinterpret_cast<Z*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
case 1:
execRank1<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 2:
execRank2<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 3:
execRank3<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 4:
execRank4<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 5:
execRank5<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
default:
execDefault<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
}

View File

@ -453,6 +453,271 @@ namespace broadcast {
}
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y, extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0], extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0], extraParams);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
@ -468,220 +733,26 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
X* extraParams = reinterpret_cast<X*>(vextraParams);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y, extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0], extraParams);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0], extraParams);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0, extraParams);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1], extraParams);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1], extraParams);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1, extraParams);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2], extraParams);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2], extraParams);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2, extraParams);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3], extraParams);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3], extraParams);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3, extraParams);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4], extraParams);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4], extraParams);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
case 1:
execRank1<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
break;
case 2:
execRank2<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
break;
case 3:
execRank3<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
break;
case 4:
execRank4<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
break;
case 5:
execRank5<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
break;
default:
execDefault<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
}
}

View File

@ -439,6 +439,271 @@ namespace functions {
}
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
////////////////////////////////////////////////////////////////////////
template <typename X>
@ -452,220 +717,26 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
X* z = reinterpret_cast<X*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], *y);
}
else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(*x, y[i0]);
}
else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) {
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[i0], y[i0]);
}
else {
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], *y0);
else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(*x0, y0[i1]);
else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[i1], y0[i1]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], *y1);
else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(*x1, y1[i2]);
else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[i2], y1[i2]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], *y2);
else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(*x2, y2[i3]);
else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[i3], y2[i3]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], *y3);
else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(*x3, y3[i4]);
else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[i4], y3[i4]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
auto func = PRAGMA_THREADS_FOR{
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coordsCPU(start, i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
case 1:
execRank1<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 2:
execRank2<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 3:
execRank3<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 4:
execRank4<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
case 5:
execRank5<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
break;
default:
execDefault<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
}

View File

@ -116,13 +116,13 @@ namespace ops {
auto shapes = SHAPELIST();
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
if(INPUT_VARIABLE(inputVar)->isEmpty()){
for (int e = 0; e < num_splits; e++) {
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
shapes->push_back(empty);
}
return shapes;
}
// if(INPUT_VARIABLE(inputVar)->isEmpty()){
// for (int e = 0; e < num_splits; e++) {
// auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
// shapes->push_back(empty);
// }
// return shapes;
// }
if (block.numI() == 2)
axis = INT_ARG(1);

View File

@ -764,50 +764,6 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, split_test4) {
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
auto axis = NDArrayFactory::create<double>(-1);
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
sd::ops::split op;
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, split_test5) {
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {2,-1},{});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {

View File

@ -923,10 +923,89 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
ASSERT_TRUE(sub0.equalsTo(z0));
ASSERT_TRUE(sub1.equalsTo(z1));
ASSERT_TRUE(sub2.equalsTo(z2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test4) {
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
auto axis = NDArrayFactory::create<double>(-1);
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
sd::ops::split op;
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test5) {
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {2,-1},{});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
auto out1 = results.at(0);
auto out2 = results.at(1);
ASSERT_TRUE(exp1.isSameShape(out1));
ASSERT_TRUE(exp2.isSameShape(out2));
ASSERT_TRUE(exp1.equalsTo(out1));
ASSERT_TRUE(exp2.equalsTo(out2));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test6) {
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
std::vector<Nd4jLong> expShape = {0,1};
const int numSplits = 4;
const int axis = 1;
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
for (int i = 0; i < numSplits; ++i)
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests4, split_test7) {
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
std::vector<Nd4jLong> expShape = {0,4};
const int numSplits = 4;
const int axis = 0;
sd::ops::split op;
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
ASSERT_EQ(ND4J_STATUS_OK, results.status());
for (int i = 0; i < numSplits; ++i)
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
}
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});

View File

@ -582,8 +582,6 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
}
//////////////////////////////////////////////////////////////////////
@ -611,8 +609,6 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
}
//////////////////////////////////////////////////////////////////////
@ -803,6 +799,25 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
ASSERT_TRUE(exp.equalsTo(output));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test27) {
auto x1 = NDArrayFactory::create<double>('c', {0,1});
auto x2 = NDArrayFactory::create<double>('c', {0,1});
auto x3 = NDArrayFactory::create<double>('c', {0,1});
auto x4 = NDArrayFactory::create<double>('c', {0,1});
std::vector<Nd4jLong> expShape = {0, 4};
sd::ops::concat op;
auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
ASSERT_TRUE(z->isSameShape(expShape));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

View File

@ -3772,6 +3772,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
ret.setData(dup(order).data());
return ret;
} else if (this.isEmpty()) {
return Nd4j.create(this.dataType(), shape);
} else {
INDArray ret = this.dup(order);
return Nd4j.create(ret.data(), shape);

View File

@ -6877,6 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id);
@ -7814,14 +7817,20 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
/**
* Convert coordinates to the corresponding linear index (sequence number in other words)
@ -7833,15 +7842,15 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
/**
* increment n-dimensional array by one iteration by changing coord appropriately
@ -7931,32 +7940,32 @@ public static final int PREALLOC_SIZE = 33554432;
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
// dimsToExclude - should be sorted in increasing order
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
// dimsToExclude - should be sorted in increasing order
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape
@ -8052,9 +8061,7 @@ public static final int PREALLOC_SIZE = 33554432;
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
@ -8930,6 +8937,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9126,6 +9137,23 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
// Nd4jLong result = 9223372036854775807LL;
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
// const auto currentStride = shape::stride(inShapeInfo)[i];
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
// continue;
// if(result > currentStride)
// result = currentStride;
// }
// return result == 9223372036854775807LL ? 1 : result;
// }
@ -9736,18 +9764,17 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
@ -9762,8 +9789,9 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
// There methods provide various validation options
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
@ -9835,9 +9863,9 @@ public static final int PREALLOC_SIZE = 33554432;
public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}

View File

@ -6880,6 +6880,9 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
@Namespace("shape") public static native void traceNew(int id);
@ -7817,14 +7820,20 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
/**
* Convert coordinates to the corresponding linear index (sequence number in other words)
@ -7836,15 +7845,15 @@ public static final int PREALLOC_SIZE = 33554432;
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
/**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
/**
* increment n-dimensional array by one iteration by changing coord appropriately
@ -7934,32 +7943,32 @@ public static final int PREALLOC_SIZE = 33554432;
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
// dimsToExclude - should be sorted in increasing order
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
// dimsToExclude - should be sorted in increasing order
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape
@ -8055,9 +8064,7 @@ public static final int PREALLOC_SIZE = 33554432;
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
@ -8933,6 +8940,10 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
@ -9129,6 +9140,23 @@ public static final int PREALLOC_SIZE = 33554432;
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
// Nd4jLong result = 9223372036854775807LL;
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
// const auto currentStride = shape::stride(inShapeInfo)[i];
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
// continue;
// if(result > currentStride)
// result = currentStride;
// }
// return result == 9223372036854775807LL ? 1 : result;
// }
@ -11949,18 +11977,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
@ -11975,8 +12002,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
// There methods provide various validation options
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
@ -12048,9 +12076,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native @Cast("Nd4jStatus") int execute(Context block);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}

View File

@ -2050,7 +2050,7 @@ public class ShapeOpValidation extends BaseOpValidation {
print(out[0]);
*/
INDArray emptyIn = Nd4j.empty(DataType.FLOAT);
INDArray emptyIn = Nd4j.empty(DataType.FLOAT).reshape(0, 4);
INDArray axis = Nd4j.scalar(1);
DynamicCustomOp op = DynamicCustomOp.builder("split")
@ -2061,9 +2061,10 @@ public class ShapeOpValidation extends BaseOpValidation {
List<LongShapeDescriptor> l = op.calculateOutputShape();
assertEquals(4, l.size());
for( int i=0; i<4; i++ ){
assertArrayEquals(new long[0], l.get(i).getShape());
assertTrue(l.get(i).isEmpty());
op.addOutputArgument(Nd4j.empty(DataType.FLOAT));
val desc = l.get(i);
assertArrayEquals(new long[]{0, 1}, desc.getShape());
assertTrue(desc.isEmpty());
op.addOutputArgument(Nd4j.empty(DataType.FLOAT).reshape(desc.getShape()));
}
Nd4j.exec(op);