diff --git a/libnd4j/include/loops/cpu/broadcasting.hpp b/libnd4j/include/loops/cpu/broadcasting.hpp index 55f9338fb..8de52cca7 100644 --- a/libnd4j/include/loops/cpu/broadcasting.hpp +++ b/libnd4j/include/loops/cpu/broadcasting.hpp @@ -572,6 +572,272 @@ template DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS); } +//////////////////////////////////////////////////////////////////////// +template +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR{ + + if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], *y); + } + else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(*x, y[i0]); + } + else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], y[i0]); + } + else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR{ + + for (auto i0 = start; i0 < stop; ++i0) { + + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], *y0); + else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(*x0, y0[i1]); + else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; + + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], *y1); + else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(*x1, y1[i2]); + else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) { + + const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR{ + + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); +} + //////////////////////////////////////////////////////////////////////// template template @@ -582,220 +848,26 @@ void Broadcast::exec(const void *vx, const Nd4jLong *xShapeInfo, const Z* z = reinterpret_cast(vz); const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - const char zOrder = shape::order(zShapeInfo); - - uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; switch (rank) { - case 1: { - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 2: { - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 3: { - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); - } - break; - - case 4: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - case 5: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - default: { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; - - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); - } + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } } diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.hpp b/libnd4j/include/loops/cpu/broadcasting_bool.hpp index b1b7eb27b..21b40cb55 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.hpp @@ -453,6 +453,271 @@ namespace broadcast { } } +//////////////////////////////////////////////////////////////////////// +template +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR{ + + if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], *y, extraParams); + } + else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(*x, y[i0], extraParams); + } + else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], y[i0], extraParams); + } + else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR{ + + for (auto i0 = start; i0 < stop; ++i0) { + + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], *y0, extraParams); + else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(*x0, y0[i1], extraParams); + else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); + } + }; + + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], *y1, extraParams); + else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(*x1, y1[i2], extraParams); + else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2, extraParams); + else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3], extraParams); + else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3, extraParams); + else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4], extraParams); + else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) { + + const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR{ + + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); +} //////////////////////////////////////////////////////////////////////// template template @@ -468,220 +733,26 @@ void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, X* extraParams = reinterpret_cast(vextraParams); const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - const char zOrder = shape::order(zShapeInfo); - - uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; switch (rank) { - case 1: { - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y, extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0], extraParams); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0], extraParams); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 2: { - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0, extraParams); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1], extraParams); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1], extraParams); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 3: { - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1, extraParams); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2], extraParams); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2], extraParams); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); - } - break; - - case 4: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2, extraParams); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3], extraParams); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3], extraParams); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - case 5: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3, extraParams); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4], extraParams); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4], extraParams); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); - } - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - default: { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); - } - }; - - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); - } + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); } } diff --git a/libnd4j/include/loops/cpu/broadcasting_int.hpp b/libnd4j/include/loops/cpu/broadcasting_int.hpp index deb8c2ea3..456994b16 100644 --- a/libnd4j/include/loops/cpu/broadcasting_int.hpp +++ b/libnd4j/include/loops/cpu/broadcasting_int.hpp @@ -439,6 +439,271 @@ namespace functions { } } +//////////////////////////////////////////////////////////////////////// +template +static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, 0); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0); + + auto func = PRAGMA_THREADS_FOR{ + + if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], *y); + } + else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(*x, y[i0]); + } + else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[i0], y[i0]); + } + else { + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0); + + auto func = PRAGMA_THREADS_FOR{ + + for (auto i0 = start; i0 < stop; ++i0) { + + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], *y0); + else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(*x0, y0[i1]); + else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[i1], y0[i1]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; + + samediff::Threads::parallel_tad(func, 0, zAxis0); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2); + + uint zAxis1 = shape::sizeAt(zShapeInfo, 1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0); + + auto func = PRAGMA_THREADS_FOR_2D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], *y1); + else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(*x1, y1[i2]); + else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[i2], y1[i2]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2); + + uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], *y2); + else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(*x2, y2[i3]); + else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[i3], y2[i3]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4); + + uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3); + + uint zAxis2 = shape::sizeAt(zShapeInfo, 2); + Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2); + Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2); + Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2); + + uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1); + + uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0); + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], *y3); + else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(*x3, y3[i4]); + else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[i4], y3[i4]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } + } + } + } + }; + + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); +} + +//////////////////////////////////////////////////////////////////////// +template +static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) { + + const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + + auto func = PRAGMA_THREADS_FOR{ + + int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coordsCPU(start, i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); +} //////////////////////////////////////////////////////////////////////// template @@ -452,220 +717,26 @@ void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, X* z = reinterpret_cast(vz); const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank - const char zOrder = shape::order(zShapeInfo); - - uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - - uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; - Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); - Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); - Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; - Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; - Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; switch (rank) { - case 1: { - - auto func = PRAGMA_THREADS_FOR{ - - if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 0) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], *y); - } - else if(zStrd0 == 1 && xStrd0 == 0 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(*x, y[i0]); - } - else if(zStrd0 == 1 && xStrd0 == 1 && yStrd0 == 1) { - for (auto i0 = start; i0 < stop; ++i0) - z[i0] = OpType::op(x[i0], y[i0]); - } - else { - for (auto i0 = start; i0 < stop; ++i0) - z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 2: { - - auto func = PRAGMA_THREADS_FOR{ - - for (auto i0 = start; i0 < stop; ++i0) { - - auto x0 = x + i0 * xStrd0; - auto y0 = y + i0 * yStrd0; - auto z0 = z + i0 * zStrd0; - - if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 0) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], *y0); - else if(zStrd1 == 1 && xStrd1 == 0 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(*x0, y0[i1]); - else if(zStrd1 == 1 && xStrd1 == 1 && yStrd1 == 1) - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1] = OpType::op(x0[i1], y0[i1]); - else - for (uint i1 = 0; i1 < zAxis1; ++i1) - z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); - } - }; - samediff::Threads::parallel_tad(func, 0, zAxis0); - } - break; - - case 3: { - - auto func = PRAGMA_THREADS_FOR_2D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - - auto x1 = x + i0 * xStrd0 + i1 * xStrd1; - auto y1 = y + i0 * yStrd0 + i1 * yStrd1; - auto z1 = z + i0 * zStrd0 + i1 * zStrd1; - - if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 0) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], *y1); - else if(zStrd2 == 1 && xStrd2 == 0 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(*x1, y1[i2]); - else if(zStrd2 == 1 && xStrd2 == 1 && yStrd2 == 1) - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2] = OpType::op(x1[i2], y1[i2]); - else - for (uint i2 = 0; i2 < zAxis2; ++i2) - z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); - } - break; - - case 4: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - - auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; - auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; - auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; - - if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 0) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], *y2); - else if(zStrd3 == 1 && xStrd3 == 0 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(*x2, y2[i3]); - else if(zStrd3 == 1 && xStrd3 == 1 && yStrd3 == 1) - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3] = OpType::op(x2[i3], y2[i3]); - else - for (uint i3 = 0; i3 < zAxis3; ++i3) - z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - case 5: { - - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (uint i3 = 0; i3 < zAxis3; ++i3) { - - auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; - auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; - auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; - - if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 0) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], *y3); - else if(zStrd4 == 1 && xStrd4 == 0 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(*x3, y3[i4]); - else if(zStrd4 == 1 && xStrd4 == 1 && yStrd4 == 1) - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4] = OpType::op(x3[i4], y3[i4]); - else - for (uint i4 = 0; i4 < zAxis4; ++i4) - z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); - } - } - } - } - }; - samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); - } - break; - - default: { - - const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); - const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo); - - auto func = PRAGMA_THREADS_FOR{ - - int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; - - for (auto i = start; i < stop; ++i) { - - shape::index2coordsCPU(start, i, zShapeInfo, zCoords); - - for (uint j = 0; j < rank; ++j) { - xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; - yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; - } - - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto xOffset = xzSameOffsets ? zOffset : shape::getOffset(xShapeInfo, xCoords); - const auto yOffset = yzSameOffsets ? zOffset : shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; - - samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); - } + case 1: + execRank1(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 2: + execRank2(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 3: + execRank3(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 4: + execRank4(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + case 5: + execRank5(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); + break; + default: + execDefault(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo); } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp index c3a3779c6..60a80378e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/split.cpp @@ -49,7 +49,7 @@ namespace ops { input = a; } } - + //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays if(input->isEmpty()){ for( int i=0; i< num_splits; i++ ){ @@ -112,17 +112,17 @@ namespace ops { inputVar = 0; } } - + auto shapes = SHAPELIST(); - + //Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays - if(INPUT_VARIABLE(inputVar)->isEmpty()){ - for (int e = 0; e < num_splits; e++) { - auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType); - shapes->push_back(empty); - } - return shapes; - } + // if(INPUT_VARIABLE(inputVar)->isEmpty()){ + // for (int e = 0; e < num_splits; e++) { + // auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType); + // shapes->push_back(empty); + // } + // return shapes; + // } if (block.numI() == 2) axis = INT_ARG(1); @@ -135,9 +135,9 @@ namespace ops { for (int e = 0; e < shape::rank(input); e++) if (e == axis) shape[e] = shape::sizeAt(input, e) / num_splits; - else + else shape[e] = shape::sizeAt(input, e); - + for (int e = 0; e < num_splits; e++) { auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dataType, shape::order(input), shape); shapes->push_back(newShape); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index efd46d1b5..03e5ae53f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -135,7 +135,7 @@ TEST_F(DeclarableOpsTests10, Test_Size_at_1) { ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(e, *result.at(0)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -170,7 +170,7 @@ TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) { ASSERT_TRUE(exp.equalsTo(res1)); ASSERT_TRUE(expIdx.equalsTo(res2)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -187,7 +187,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) { ASSERT_TRUE(exp.isSameShape(resA)); ASSERT_TRUE(exp.equalsTo(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -204,7 +204,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) { ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -228,7 +228,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { ASSERT_TRUE(exp2.equalsTo(res2)); ASSERT_TRUE(exp3.equalsTo(res3)); //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -245,7 +245,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { ASSERT_TRUE(exp1.equalsTo(res.at(0))); ASSERT_TRUE(exp2.equalsTo(res.at(1))); //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -263,7 +263,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) { ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -281,7 +281,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) { ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -300,7 +300,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) { //ASSERT_TRUE(exp.equalsTo(resA)); //ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -318,7 +318,7 @@ TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) { ASSERT_TRUE(exp.equalsTo(resA)); ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -337,7 +337,7 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) { //ASSERT_TRUE(exp.equalsTo(resA)); //ASSERT_TRUE(exp.isSameShape(resA)); // ASSERT_TRUE(expIdx.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -354,7 +354,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) { auto resA = res.at(0); ASSERT_TRUE(exp.equalsTo(resA)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -371,7 +371,7 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) { auto resA = res.at(0); ASSERT_TRUE(exp.equalsTo(resA)); - + } /////////////////////////////////////////////////////////////////// @@ -567,7 +567,7 @@ TEST_F(DeclarableOpsTests10, LGamma_Test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////////////// @@ -764,50 +764,6 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4 ASSERT_TRUE(expected.equalsTo(output)); } -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, split_test4) { - - auto input = NDArrayFactory::create('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}); - auto axis = NDArrayFactory::create(-1); - auto exp1 = NDArrayFactory::create('c', {5}, {1.f,2.f,3.f,4.f,5.f}); - auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f}); - - sd::ops::split op; - auto results = op.evaluate({&input, &axis}, {}, {2}, {}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out1 = results.at(0); - auto out2 = results.at(1); - - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); -} - - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests10, split_test5) { - - auto input = NDArrayFactory::create('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f}); - auto exp1 = NDArrayFactory::create('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f}); - auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); - - sd::ops::split op; - auto results = op.evaluate({&input}, {}, {2,-1},{}); - - ASSERT_EQ(ND4J_STATUS_OK, results.status()); - - auto out1 = results.at(0); - auto out2 = results.at(1); - - ASSERT_TRUE(exp1.isSameShape(out1)); - ASSERT_TRUE(exp2.isSameShape(out2)); - ASSERT_TRUE(exp1.equalsTo(out1)); - ASSERT_TRUE(exp2.equalsTo(out2)); -} - /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) { @@ -1464,7 +1420,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { //expected.printIndexedBuffer("Expect for 10x10"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { @@ -1511,7 +1467,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { // expected.printBuffer("Expect for 4x5"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { @@ -1569,7 +1525,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { // expected.printShapeInfo("Expect shape"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { @@ -1724,7 +1680,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { // expected.printShapeInfo("Expect shape"); ASSERT_TRUE(expected.isSameShape(result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// @@ -2053,7 +2009,7 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) { auto res = result.at(0); ASSERT_TRUE(expect.equalsTo(res)); - + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { @@ -2276,7 +2232,7 @@ TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { @@ -2315,7 +2271,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { ASSERT_TRUE(expected.isSameShapeStrict(*result)); ASSERT_TRUE(expected.equalsTo(result)); - + } //////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index d930fcc36..69dec8359 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -73,7 +73,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -94,7 +94,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -115,7 +115,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -156,7 +156,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -177,7 +177,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -198,7 +198,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -218,7 +218,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -242,7 +242,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -339,7 +339,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_11) { ASSERT_EQ(m, *z); - + } ////////////////////////////////////////////////////////////////////// @@ -373,7 +373,7 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -546,7 +546,7 @@ TEST_F(DeclarableOpsTests4, biasadd_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, biasadd_2) { @@ -564,7 +564,7 @@ TEST_F(DeclarableOpsTests4, biasadd_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, biasadd_3) { @@ -581,7 +581,7 @@ TEST_F(DeclarableOpsTests4, biasadd_3) { ASSERT_TRUE(exp.isSameShape(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -609,7 +609,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_1) { ASSERT_TRUE(gradB->isSameShape(expGradB)); ASSERT_TRUE(gradB->equalsTo(expGradB)); - + } ////////////////////////////////////////////////////////////////////// @@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests4, biasadd_bp_2) { ASSERT_TRUE(gradB->isSameShape(expGradB)); ASSERT_TRUE(gradB->equalsTo(expGradB)); - + } TEST_F(DeclarableOpsTests4, biasadd_4) { @@ -672,7 +672,7 @@ TEST_F(DeclarableOpsTests4, Test_Fill_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { @@ -694,7 +694,7 @@ TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { // ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { @@ -713,7 +713,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { // z->printShapeInfo("Flatten1 shape"); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { @@ -734,7 +734,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { @@ -752,7 +752,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { @@ -770,7 +770,7 @@ TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { @@ -788,7 +788,7 @@ TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { // z->printShapeInfo("Flatten1 shape"); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { @@ -806,7 +806,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Again) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_Split_1) { @@ -845,7 +845,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_1) { ASSERT_TRUE(sub1.equalsTo(z1)); ASSERT_TRUE(sub2.equalsTo(z2)); - + } // special test for TF mode, when axis goes first @@ -888,7 +888,7 @@ TEST_F(DeclarableOpsTests4, Test_Split_2) { ASSERT_TRUE(sub2.equalsTo(z2)); ASSERT_TRUE(sub3.equalsTo(z3)); - + } // special test for TF mode, when axis goes first @@ -923,10 +923,89 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) { ASSERT_TRUE(sub0.equalsTo(z0)); ASSERT_TRUE(sub1.equalsTo(z1)); ASSERT_TRUE(sub2.equalsTo(z2)); - - } +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test4) { + + auto input = NDArrayFactory::create('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}); + auto axis = NDArrayFactory::create(-1); + auto exp1 = NDArrayFactory::create('c', {5}, {1.f,2.f,3.f,4.f,5.f}); + auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f}); + + sd::ops::split op; + auto results = op.evaluate({&input, &axis}, {}, {2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out1 = results.at(0); + auto out2 = results.at(1); + + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test5) { + + auto input = NDArrayFactory::create('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f}); + auto exp1 = NDArrayFactory::create('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f}); + auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {2,-1},{}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out1 = results.at(0); + auto out2 = results.at(1); + + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test6) { + + NDArray input('c', {0,4}, sd::DataType::FLOAT32); + std::vector expShape = {0,1}; + + const int numSplits = 4; + const int axis = 1; + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i)->isSameShape(expShape)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test7) { + + NDArray input('c', {0,4}, sd::DataType::FLOAT32); + std::vector expShape = {0,4}; + + const int numSplits = 4; + const int axis = 0; + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i)->isSameShape(expShape)); +} + + TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); @@ -940,7 +1019,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { @@ -957,7 +1036,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -974,7 +1053,7 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { @@ -990,7 +1069,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { @@ -1006,7 +1085,7 @@ TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1023,7 +1102,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1040,7 +1119,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { @@ -1055,7 +1134,7 @@ TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { ASSERT_TRUE(exp.isSameShape(z)); - + } @@ -1073,7 +1152,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1091,7 +1170,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -1109,7 +1188,7 @@ TEST_F(DeclarableOpsTests4, Test_Cross_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_Add_119) { @@ -1129,7 +1208,7 @@ TEST_F(DeclarableOpsTests4, Test_Add_119) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) { @@ -1146,7 +1225,7 @@ TEST_F(DeclarableOpsTests4, Test_Reshape_Negative_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { @@ -1165,7 +1244,7 @@ TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { @@ -1184,7 +1263,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { @@ -1208,7 +1287,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { @@ -1229,7 +1308,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { auto z = result.at(0); ASSERT_TRUE(z->isEmpty()); - + } TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { auto x = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); @@ -1248,7 +1327,7 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { auto z = result.at(0); ASSERT_TRUE(z->lengthOf() == 1); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1272,7 +1351,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1293,7 +1372,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test2) { ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1313,7 +1392,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } \ ////////////////////////////////////////////////////////////////////// @@ -1333,7 +1412,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1353,7 +1432,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test5) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1373,7 +1452,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test6) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1390,7 +1469,7 @@ TEST_F(DeclarableOpsTests4, parallel_stack_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1419,7 +1498,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test1) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1446,7 +1525,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test2) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1473,7 +1552,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test3) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1500,7 +1579,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test4) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1527,7 +1606,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test5) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1554,7 +1633,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test6) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1581,7 +1660,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test7) { ASSERT_TRUE(exp2.isSameShape(out2)); ASSERT_TRUE(exp2.equalsTo(out2)); - + } ////////////////////////////////////////////////////////////////////// @@ -1598,7 +1677,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test8) { ASSERT_TRUE(exp0.isSameShape(out0)); ASSERT_TRUE(exp0.equalsTo(out0)); - + } ////////////////////////////////////////////////////////////////////// @@ -1615,7 +1694,7 @@ TEST_F(DeclarableOpsTests4, meshgrid_test9) { ASSERT_TRUE(exp0.isSameShape(out0)); ASSERT_TRUE(exp0.equalsTo(out0)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1644,7 +1723,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1664,7 +1743,7 @@ TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1718,7 +1797,7 @@ TEST_F(DeclarableOpsTests4, lstm_test1) { ASSERT_TRUE(expClast.isSameShape(&cLast)); ASSERT_TRUE(expClast.equalsTo(&cLast)); - + } /////////////////////////////////////////////////////////////////// @@ -1737,7 +1816,7 @@ TEST_F(DeclarableOpsTests4, relu6_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -1759,7 +1838,7 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1788,7 +1867,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } @@ -1816,7 +1895,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1855,7 +1934,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1894,7 +1973,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1939,7 +2018,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { // exp.printIndexedBuffer("LRN exp"); // ASSERT_TRUE(exp.equalsTo(out)); - + } @@ -1962,7 +2041,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1983,7 +2062,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2004,7 +2083,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2025,7 +2104,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2044,7 +2123,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2065,7 +2144,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2086,7 +2165,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2105,7 +2184,7 @@ TEST_F(DeclarableOpsTests4, triu_test1) { ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2123,7 +2202,7 @@ TEST_F(DeclarableOpsTests4, triu_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2141,7 +2220,7 @@ TEST_F(DeclarableOpsTests4, triu_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2159,7 +2238,7 @@ TEST_F(DeclarableOpsTests4, triu_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2177,7 +2256,7 @@ TEST_F(DeclarableOpsTests4, triu_test5) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2195,7 +2274,7 @@ TEST_F(DeclarableOpsTests4, triu_test6) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2213,7 +2292,7 @@ TEST_F(DeclarableOpsTests4, triu_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2231,7 +2310,7 @@ TEST_F(DeclarableOpsTests4, triu_test8) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2249,7 +2328,7 @@ TEST_F(DeclarableOpsTests4, triu_test9) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2267,7 +2346,7 @@ TEST_F(DeclarableOpsTests4, triu_test10) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2285,7 +2364,7 @@ TEST_F(DeclarableOpsTests4, triu_test11) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } @@ -2307,7 +2386,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) { ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -2328,7 +2407,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test2) { ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -2349,7 +2428,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test3) { ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -2370,6 +2449,6 @@ TEST_F(DeclarableOpsTests4, triu_bp_test4) { ASSERT_TRUE(expected.isSameShape(gradI)); ASSERT_TRUE(expected.equalsTo(gradI)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index a1f9c40a1..c7e704a21 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -57,14 +57,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { // output->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -86,14 +86,14 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) { // output->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } /* @@ -255,7 +255,7 @@ TEST_F(DeclarableOpsTests9, concat_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests9, concat_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -305,7 +305,7 @@ TEST_F(DeclarableOpsTests9, concat_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -325,7 +325,7 @@ TEST_F(DeclarableOpsTests9, concat_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -345,7 +345,7 @@ TEST_F(DeclarableOpsTests9, concat_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -365,7 +365,7 @@ TEST_F(DeclarableOpsTests9, concat_test6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -385,7 +385,7 @@ TEST_F(DeclarableOpsTests9, concat_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -403,7 +403,7 @@ TEST_F(DeclarableOpsTests9, concat_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests9, concat_test9) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -446,7 +446,7 @@ TEST_F(DeclarableOpsTests9, concat_test10) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -471,7 +471,7 @@ TEST_F(DeclarableOpsTests9, concat_test11) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -496,7 +496,7 @@ TEST_F(DeclarableOpsTests9, concat_test12) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -522,7 +522,7 @@ TEST_F(DeclarableOpsTests9, concat_test13) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } TEST_F(DeclarableOpsTests9, concat_test14) { @@ -548,7 +548,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) { ASSERT_NEAR((e+1)*1., mean, 1e-5); } - + } TEST_F(DeclarableOpsTests9, concat_test15) { @@ -565,7 +565,7 @@ TEST_F(DeclarableOpsTests9, concat_test15) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -582,8 +582,6 @@ TEST_F(DeclarableOpsTests9, concat_test16) { auto z = result.at(0); ASSERT_TRUE(exp.isSameShape(z)); - - } ////////////////////////////////////////////////////////////////////// @@ -611,8 +609,6 @@ TEST_F(DeclarableOpsTests9, concat_test17) { auto mean = tad.meanNumber().e(0); ASSERT_NEAR((e+1)*1., mean, 1e-5); } - - } ////////////////////////////////////////////////////////////////////// @@ -693,7 +689,7 @@ TEST_F(DeclarableOpsTests9, concat_test20) { ASSERT_NEAR((double) e+1, mean, 1e-5); } - + } //////////////////////////////////////////////////////////////////////////////// @@ -775,7 +771,7 @@ TEST_F(DeclarableOpsTests9, concat_test25) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -803,6 +799,25 @@ TEST_F(DeclarableOpsTests9, concat_test26) { ASSERT_TRUE(exp.equalsTo(output)); } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test27) { + + auto x1 = NDArrayFactory::create('c', {0,1}); + auto x2 = NDArrayFactory::create('c', {0,1}); + auto x3 = NDArrayFactory::create('c', {0,1}); + auto x4 = NDArrayFactory::create('c', {0,1}); + + std::vector expShape = {0, 4}; + + sd::ops::concat op; + auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(expShape)); +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test1) { @@ -820,7 +835,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test1) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -839,7 +854,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test2) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -859,7 +874,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test3) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -879,7 +894,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test4) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -899,7 +914,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test5) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -919,7 +934,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test6) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -940,7 +955,7 @@ TEST_F(DeclarableOpsTests9, tile_bp_test7) { ASSERT_TRUE(gradIExp.isSameShape(gradI)); ASSERT_TRUE(gradIExp.equalsTo(gradI)); - + } ////////////////////////////////////////////////////////////////////// @@ -958,7 +973,7 @@ TEST_F(DeclarableOpsTests9, tile_test1) { ASSERT_TRUE(expOut.isSameShape(out)); ASSERT_TRUE(expOut.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -975,7 +990,7 @@ TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { //ress.at(0)->printIndexedBuffer("Result is "); //x.printIndexedBuffer("Input is"); ASSERT_FALSE(ress.at(0)->equalsTo(errs)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1006,8 +1021,8 @@ TEST_F(DeclarableOpsTests9, TestDropout_1) { //res->printIndexedBuffer("FF dropout"); //res2->printIndexedBuffer("BP dropout"); - - + + } TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { @@ -1080,7 +1095,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { */ // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); - + } TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { @@ -1117,7 +1132,7 @@ TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { ASSERT_NEAR(countZero.e(0), 50.f, 10.f); // ASSERT_TRUE(exp.equalsTo(ressX->at(0))); ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0))); - + } @@ -1143,7 +1158,7 @@ TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { //res->printIndexedBuffer("Result1AlphaBP1"); //res2->printIndexedBuffer("Result1AlphaBP2"); ASSERT_TRUE(res2->equalsTo(res)); - + } TEST_F(DeclarableOpsTests9, test_range_int_1) { @@ -1227,7 +1242,7 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { ASSERT_TRUE(result.at(i)->isSameShape(z[i])); ASSERT_TRUE(result.at(i)->equalsTo(z[i])); } - + } //////////////////////////////////////////////////////////////////////////////// @@ -1263,7 +1278,7 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) { ASSERT_TRUE(expect.isSameShape(outFF)); ASSERT_TRUE(expect.equalsTo(outFF)); - + } @@ -1355,7 +1370,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); ASSERT_TRUE(expFF.equalsTo(z)); - + //************************************// exclusive = 1; reverse = 0; @@ -1364,7 +1379,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expTF.equalsTo(z)); - + //************************************// exclusive = 0; reverse = 1; @@ -1373,7 +1388,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expFT.equalsTo(z)); - + //************************************// exclusive = 1; reverse = 1; @@ -1382,7 +1397,7 @@ TEST_F(DeclarableOpsTests9, cumprod_1) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expTT.equalsTo(z)); - + } @@ -1416,7 +1431,7 @@ TEST_F(DeclarableOpsTests9, cumprod_2) { ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -1584,7 +1599,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1602,7 +1617,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1620,7 +1635,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1638,7 +1653,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1656,7 +1671,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1674,7 +1689,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -1693,7 +1708,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1711,7 +1726,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1729,7 +1744,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1747,7 +1762,7 @@ TEST_F(DeclarableOpsTests9, prelu_test10) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1772,7 +1787,7 @@ TEST_F(DeclarableOpsTests9, prelu_test11) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1796,7 +1811,7 @@ TEST_F(DeclarableOpsTests9, prelu_test12) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1820,7 +1835,7 @@ TEST_F(DeclarableOpsTests9, prelu_test13) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1845,7 +1860,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1864,7 +1879,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1883,7 +1898,7 @@ TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { // output->printIndexedBuffer("Packed to uint8"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1902,7 +1917,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2017,7 +2032,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2037,7 +2052,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2057,7 +2072,7 @@ TEST_F(DeclarableOpsTests9, multiply_test3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2076,7 +2091,7 @@ TEST_F(DeclarableOpsTests9, multiply_test4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2094,7 +2109,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2403,7 +2418,7 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { dLdu.assign(dLdh * dhdu); dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose())); - + const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); @@ -2430,7 +2445,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { auto res = result.at(0); // res->printIndexedBuffer("Output for Cholesky1"); ASSERT_TRUE(exp.equalsTo(res)); - + } //////////////////////////////////////////////////////////////////// @@ -2446,7 +2461,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { auto res = result.at(0); // res->printIndexedBuffer("Output for Cholesky 2"); ASSERT_TRUE(exp.equalsTo(res)); - + } //////////////////////////////////////////////////////////////////// @@ -2462,7 +2477,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { auto res = result.at(0); // res->printIndexedBuffer("Output for Cholesky 3"); ASSERT_TRUE(exp.equalsTo(res, 1e-4)); - + } //////////////////////////////////////////////////////////////////// diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 0060d84f9..3edbd2682 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -3772,6 +3772,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order); ret.setData(dup(order).data()); return ret; + } else if (this.isEmpty()) { + return Nd4j.create(this.dataType(), shape); } else { INDArray ret = this.dup(order); return Nd4j.create(ret.data(), shape); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 9b0a01d39..98521d58c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -4997,8 +4997,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); ******************************************************************************/ // -// This class is suited for execution results representation. -// +// This class is suited for execution results representation. +// // PLESE NOTE: It will delete all stored NDArrays upon destructor call // // @author raver119@gmail.com @@ -5011,7 +5011,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include // forward declaration of template class NDArray - + @Namespace("sd") @NoOffset public static class ResultSet extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -6877,6 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); @Namespace("shape") public static native void traceNew(int id); @@ -7814,14 +7817,20 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); + + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); + /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); - - + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims); /** * Convert coordinates to the corresponding linear index (sequence number in other words) @@ -7833,15 +7842,15 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -7931,32 +7940,32 @@ public static final int PREALLOC_SIZE = 33554432; // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // dimsToExclude - should be sorted in increasing order // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand // dimsToExclude - should be sorted in increasing order // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff); // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // rank is equal to size of shape @@ -8052,9 +8061,7 @@ public static final int PREALLOC_SIZE = 33554432; * get stride over contiguous axis (contiguous axis must have stride = 1) * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) */ - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo); + // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); @@ -8930,6 +8937,10 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////// // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { @@ -9126,6 +9137,23 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { + +// Nd4jLong result = 9223372036854775807LL; + +// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) { + +// const auto currentStride = shape::stride(inShapeInfo)[i]; + +// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1) +// continue; + +// if(result > currentStride) +// result = currentStride; +// } + +// return result == 9223372036854775807LL ? 1 : result; +// } @@ -9736,18 +9764,17 @@ public static final int PREALLOC_SIZE = 33554432; public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); @@ -9762,8 +9789,9 @@ public static final int PREALLOC_SIZE = 33554432; public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet execute(@Const @ByRef OpArgsHolder holder); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); + // There methods provide various validation options public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); @@ -9835,9 +9863,9 @@ public static final int PREALLOC_SIZE = 33554432; public native @Cast("Nd4jStatus") int execute(Context block); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 5105eb5d2..240dbc843 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -5000,8 +5000,8 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); ******************************************************************************/ // -// This class is suited for execution results representation. -// +// This class is suited for execution results representation. +// // PLESE NOTE: It will delete all stored NDArrays upon destructor call // // @author raver119@gmail.com @@ -5014,7 +5014,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); // #include // #include // #include // forward declaration of template class NDArray - + @Namespace("sd") @NoOffset public static class ResultSet extends Pointer { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ @@ -6880,6 +6880,9 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); @Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim); + @Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim); @Namespace("shape") public static native void traceNew(int id); @@ -7817,14 +7820,20 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords); @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords); + + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords); + @Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords); + /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); - - + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims); /** * Convert coordinates to the corresponding linear index (sequence number in other words) @@ -7836,15 +7845,15 @@ public static final int PREALLOC_SIZE = 33554432; @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords); @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords); /** * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! */ - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims); - @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims); + @Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims); /** * increment n-dimensional array by one iteration by changing coord appropriately @@ -7934,32 +7943,32 @@ public static final int PREALLOC_SIZE = 33554432; // function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs) // dimsToExclude - should be sorted in increasing order // dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); - @Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); + @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/); + @Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array // dimsToExclude - should be sorted in increasing order - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); + @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand // dimsToExclude - should be sorted in increasing order // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); - @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/); + @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff); // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // rank is equal to size of shape @@ -8055,9 +8064,7 @@ public static final int PREALLOC_SIZE = 33554432; * get stride over contiguous axis (contiguous axis must have stride = 1) * for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1) */ - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo); - @Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo); + // INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo); @@ -8933,6 +8940,10 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////// + ////////////////////////////////////////////////////////////////////// // INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) { @@ -9129,6 +9140,23 @@ public static final int PREALLOC_SIZE = 33554432; ////////////////////////////////////////////////////////////////////// +// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { + +// Nd4jLong result = 9223372036854775807LL; + +// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) { + +// const auto currentStride = shape::stride(inShapeInfo)[i]; + +// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1) +// continue; + +// if(result > currentStride) +// result = currentStride; +// } + +// return result == 9223372036854775807LL ? 1 : result; +// } @@ -11949,18 +11977,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs); - - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs); @@ -11975,8 +12002,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/); public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs); - public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); - public native ResultSet execute(@Const @ByRef OpArgsHolder holder); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/); + public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder); + // There methods provide various validation options public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block); @@ -12048,9 +12076,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native @Cast("Nd4jStatus") int execute(Context block); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); - public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs); + public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 733628490..30d4baf5c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -2050,7 +2050,7 @@ public class ShapeOpValidation extends BaseOpValidation { print(out[0]); */ - INDArray emptyIn = Nd4j.empty(DataType.FLOAT); + INDArray emptyIn = Nd4j.empty(DataType.FLOAT).reshape(0, 4); INDArray axis = Nd4j.scalar(1); DynamicCustomOp op = DynamicCustomOp.builder("split") @@ -2061,9 +2061,10 @@ public class ShapeOpValidation extends BaseOpValidation { List l = op.calculateOutputShape(); assertEquals(4, l.size()); for( int i=0; i<4; i++ ){ - assertArrayEquals(new long[0], l.get(i).getShape()); - assertTrue(l.get(i).isEmpty()); - op.addOutputArgument(Nd4j.empty(DataType.FLOAT)); + val desc = l.get(i); + assertArrayEquals(new long[]{0, 1}, desc.getShape()); + assertTrue(desc.isEmpty()); + op.addOutputArgument(Nd4j.empty(DataType.FLOAT).reshape(desc.getShape())); } Nd4j.exec(op);