correct output empty shapes deducing in split op (#311)
* - correct output empty shapes deducing in split op Signed-off-by: Yurii <iuriish@yahoo.com> * java test fixed Signed-off-by: raver119 <raver119@gmail.com> * - split broadcast::exec function on individual functions corresponding to switch arg Signed-off-by: Yurii <iuriish@yahoo.com> * - split broadcast::exec _int and _bool function on individual functions corresponding to switch arg Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
41bde8f885
commit
e42b4e96c3
|
@ -573,53 +573,13 @@ template <typename X, typename Y, typename Z>
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const Y* y = reinterpret_cast<const Y*>(vy);
|
||||
Z* z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
const char zOrder = shape::order(zShapeInfo);
|
||||
|
||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1: {
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -641,10 +601,21 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 2: {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -668,11 +639,28 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
|
||||
case 3: {
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
|
||||
|
@ -698,11 +686,33 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 4: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -730,11 +740,38 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 5: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
|
||||
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -764,15 +801,19 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
default: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const Y *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||
|
@ -795,7 +836,38 @@ void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const
|
|||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const Y* y = reinterpret_cast<const Y*>(vy);
|
||||
Z* z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1:
|
||||
execRank1<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 2:
|
||||
execRank2<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 3:
|
||||
execRank3<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 4:
|
||||
execRank4<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 5:
|
||||
execRank5<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
default:
|
||||
execDefault<X,Y,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -454,58 +454,13 @@ namespace broadcast {
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||
const void *vy, const Nd4jLong *yShapeInfo,
|
||||
void *vz, const Nd4jLong *zShapeInfo,
|
||||
void *vextraParams) {
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const X* y = reinterpret_cast<const X*>(vy);
|
||||
Z* z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
X* extraParams = reinterpret_cast<X*>(vextraParams);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
const char zOrder = shape::order(zShapeInfo);
|
||||
|
||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1: {
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -527,10 +482,21 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 2: {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -554,11 +520,28 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
|
||||
case 3: {
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
|
||||
|
@ -584,11 +567,33 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 4: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -616,11 +621,38 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 5: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
|
||||
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -650,15 +682,19 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
default: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, Z* z, const Nd4jLong *zShapeInfo, X* extraParams) {
|
||||
|
||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||
|
@ -681,7 +717,42 @@ void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||
}
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||
const void *vy, const Nd4jLong *yShapeInfo,
|
||||
void *vz, const Nd4jLong *zShapeInfo,
|
||||
void *vextraParams) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const X* y = reinterpret_cast<const X*>(vy);
|
||||
Z* z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
X* extraParams = reinterpret_cast<X*>(vextraParams);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1:
|
||||
execRank1<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
break;
|
||||
case 2:
|
||||
execRank2<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
break;
|
||||
case 3:
|
||||
execRank3<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
break;
|
||||
case 4:
|
||||
execRank4<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
break;
|
||||
case 5:
|
||||
execRank5<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
break;
|
||||
default:
|
||||
execDefault<X,Z, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -439,57 +439,14 @@ namespace functions {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||
const void *vy, const Nd4jLong *yShapeInfo,
|
||||
void *vz, const Nd4jLong *zShapeInfo) {
|
||||
template <typename X, typename OpType>
|
||||
static void execRank1(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const X* y = reinterpret_cast<const X*>(vy);
|
||||
X* z = reinterpret_cast<X*>(vz);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
const char zOrder = shape::order(zShapeInfo);
|
||||
|
||||
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
|
||||
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
|
||||
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
|
||||
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1: {
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, 0);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, 0);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, 0);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -511,10 +468,21 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 2: {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
static void execRank2(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 1);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
|
@ -538,11 +506,28 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
break;
|
||||
|
||||
case 3: {
|
||||
samediff::Threads::parallel_tad(func, 0, zAxis0);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
static void execRank3(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 2);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, 1);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, 1);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, 1);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, 1);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_2D {
|
||||
|
||||
|
@ -568,11 +553,33 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 4: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
static void execRank4(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 3);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 2);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 2 : 1);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -600,11 +607,38 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
case 5: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
static void execRank5(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
uint zAxis0 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 0 : 4);
|
||||
|
||||
uint zAxis1 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 1 : 3);
|
||||
|
||||
uint zAxis2 = shape::sizeAt(zShapeInfo, 2);
|
||||
Nd4jLong xStrd2 = shape::strideAt(xShapeInfo, 2);
|
||||
Nd4jLong yStrd2 = shape::strideAt(yShapeInfo, 2);
|
||||
Nd4jLong zStrd2 = shape::strideAt(zShapeInfo, 2);
|
||||
|
||||
uint zAxis3 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong xStrd3 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong yStrd3 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
Nd4jLong zStrd3 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 3 : 1);
|
||||
|
||||
uint zAxis4 = shape::sizeAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong xStrd4 = shape::strideAt(xShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong yStrd4 = shape::strideAt(yShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
Nd4jLong zStrd4 = shape::strideAt(zShapeInfo, shape::order(zShapeInfo) == 'c' ? 4 : 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR_3D {
|
||||
|
||||
|
@ -634,15 +668,19 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
break;
|
||||
|
||||
default: {
|
||||
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
static void execDefault(const X *x, const Nd4jLong *xShapeInfo, const X *y, const Nd4jLong *yShapeInfo, X* z, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const bool xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||
const bool yzSameOffsets = shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
int xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
|
||||
|
@ -665,7 +703,40 @@ void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
|||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
|
||||
const void *vy, const Nd4jLong *yShapeInfo,
|
||||
void *vz, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
const X* x = reinterpret_cast<const X*>(vx);
|
||||
const X* y = reinterpret_cast<const X*>(vy);
|
||||
X* z = reinterpret_cast<X*>(vz);
|
||||
|
||||
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
|
||||
|
||||
switch (rank) {
|
||||
|
||||
case 1:
|
||||
execRank1<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 2:
|
||||
execRank2<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 3:
|
||||
execRank3<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 4:
|
||||
execRank4<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
case 5:
|
||||
execRank5<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
break;
|
||||
default:
|
||||
execDefault<X, OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -116,13 +116,13 @@ namespace ops {
|
|||
auto shapes = SHAPELIST();
|
||||
|
||||
//Edge case: splitting empty array (mainly for TF import compatibility) -> return N empty arrays
|
||||
if(INPUT_VARIABLE(inputVar)->isEmpty()){
|
||||
for (int e = 0; e < num_splits; e++) {
|
||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
|
||||
shapes->push_back(empty);
|
||||
}
|
||||
return shapes;
|
||||
}
|
||||
// if(INPUT_VARIABLE(inputVar)->isEmpty()){
|
||||
// for (int e = 0; e < num_splits; e++) {
|
||||
// auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dataType);
|
||||
// shapes->push_back(empty);
|
||||
// }
|
||||
// return shapes;
|
||||
// }
|
||||
|
||||
if (block.numI() == 2)
|
||||
axis = INT_ARG(1);
|
||||
|
|
|
@ -764,50 +764,6 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4
|
|||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, split_test4) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
|
||||
auto axis = NDArrayFactory::create<double>(-1);
|
||||
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
|
||||
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto out1 = results.at(0);
|
||||
auto out2 = results.at(1);
|
||||
|
||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, split_test5) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
|
||||
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
|
||||
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input}, {}, {2,-1},{});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto out1 = results.at(0);
|
||||
auto out2 = results.at(1);
|
||||
|
||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) {
|
||||
|
||||
|
|
|
@ -923,10 +923,89 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
|
|||
ASSERT_TRUE(sub0.equalsTo(z0));
|
||||
ASSERT_TRUE(sub1.equalsTo(z1));
|
||||
ASSERT_TRUE(sub2.equalsTo(z2));
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests4, split_test4) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f});
|
||||
auto axis = NDArrayFactory::create<double>(-1);
|
||||
auto exp1 = NDArrayFactory::create<double>('c', {5}, {1.f,2.f,3.f,4.f,5.f});
|
||||
auto exp2 = NDArrayFactory::create<double>('c', {5}, {6.f,7.f,8.f,9.f,10.f});
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input, &axis}, {}, {2}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto out1 = results.at(0);
|
||||
auto out2 = results.at(1);
|
||||
|
||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests4, split_test5) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f});
|
||||
auto exp1 = NDArrayFactory::create<double>('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f});
|
||||
auto exp2 = NDArrayFactory::create<double>('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f});
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input}, {}, {2,-1},{});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
auto out1 = results.at(0);
|
||||
auto out2 = results.at(1);
|
||||
|
||||
ASSERT_TRUE(exp1.isSameShape(out1));
|
||||
ASSERT_TRUE(exp2.isSameShape(out2));
|
||||
ASSERT_TRUE(exp1.equalsTo(out1));
|
||||
ASSERT_TRUE(exp2.equalsTo(out2));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests4, split_test6) {
|
||||
|
||||
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
|
||||
std::vector<Nd4jLong> expShape = {0,1};
|
||||
|
||||
const int numSplits = 4;
|
||||
const int axis = 1;
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
for (int i = 0; i < numSplits; ++i)
|
||||
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests4, split_test7) {
|
||||
|
||||
NDArray input('c', {0,4}, sd::DataType::FLOAT32);
|
||||
std::vector<Nd4jLong> expShape = {0,4};
|
||||
|
||||
const int numSplits = 4;
|
||||
const int axis = 0;
|
||||
|
||||
sd::ops::split op;
|
||||
auto results = op.evaluate({&input}, {}, {numSplits, axis}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results.status());
|
||||
|
||||
for (int i = 0; i < numSplits; ++i)
|
||||
ASSERT_TRUE(results.at(i)->isSameShape(expShape));
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
|
||||
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
|
||||
|
|
|
@ -582,8 +582,6 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
|
|||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -611,8 +609,6 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
|
|||
auto mean = tad.meanNumber().e<double>(0);
|
||||
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -803,6 +799,25 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
|
|||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test27) {
|
||||
|
||||
auto x1 = NDArrayFactory::create<double>('c', {0,1});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {0,1});
|
||||
auto x3 = NDArrayFactory::create<double>('c', {0,1});
|
||||
auto x4 = NDArrayFactory::create<double>('c', {0,1});
|
||||
|
||||
std::vector<Nd4jLong> expShape = {0, 4};
|
||||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(z->isSameShape(expShape));
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||
|
||||
|
|
|
@ -3772,6 +3772,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order);
|
||||
ret.setData(dup(order).data());
|
||||
return ret;
|
||||
} else if (this.isEmpty()) {
|
||||
return Nd4j.create(this.dataType(), shape);
|
||||
} else {
|
||||
INDArray ret = this.dup(order);
|
||||
return Nd4j.create(ret.data(), shape);
|
||||
|
|
|
@ -6877,6 +6877,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
|
||||
@Namespace("shape") public static native void traceNew(int id);
|
||||
|
||||
|
@ -7814,14 +7817,20 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
||||
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||
|
||||
/**
|
||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||
*/
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
/**
|
||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||
|
@ -7833,15 +7842,15 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
|
||||
/**
|
||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
/**
|
||||
* increment n-dimensional array by one iteration by changing coord appropriately
|
||||
|
@ -7931,32 +7940,32 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
|
||||
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
|
||||
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
|
||||
|
||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||
// rank is equal to size of shape
|
||||
|
@ -8052,9 +8061,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
||||
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
|
||||
|
||||
|
||||
|
||||
|
@ -8930,6 +8937,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||
|
||||
|
@ -9126,6 +9137,23 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
||||
|
||||
// Nd4jLong result = 9223372036854775807LL;
|
||||
|
||||
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
|
||||
|
||||
// const auto currentStride = shape::stride(inShapeInfo)[i];
|
||||
|
||||
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
|
||||
// continue;
|
||||
|
||||
// if(result > currentStride)
|
||||
// result = currentStride;
|
||||
// }
|
||||
|
||||
// return result == 9223372036854775807LL ? 1 : result;
|
||||
// }
|
||||
|
||||
|
||||
|
||||
|
@ -9736,18 +9764,17 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||
|
@ -9762,8 +9789,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||
|
||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||
|
||||
|
||||
// There methods provide various validation options
|
||||
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
||||
|
@ -9835,9 +9863,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
|
||||
public native @Cast("Nd4jStatus") int execute(Context block);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
|
|
|
@ -6880,6 +6880,9 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native int sizeAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongPointer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") LongBuffer shapeInfo, int dim);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideAt(@Cast("const Nd4jLong*") long[] shapeInfo, int dim);
|
||||
|
||||
@Namespace("shape") public static native void traceNew(int id);
|
||||
|
||||
|
@ -7817,14 +7820,20 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongPointer shape, IntPointer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") LongBuffer shape, IntBuffer coords);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, int rank, @Cast("const Nd4jLong*") long[] shape, int[] coords);
|
||||
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords);
|
||||
@Namespace("shape") public static native void index2coordsCPU(@Cast("const Nd4jLong") long startIndex, @Cast("const Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords);
|
||||
|
||||
/**
|
||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||
*/
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, @Cast("Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongPointer shapeInfo, IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") LongBuffer shapeInfo, IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native void index2coords(@Cast("Nd4jLong") long index, @Cast("const Nd4jLong*") long[] shapeInfo, int[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
/**
|
||||
* Convert coordinates to the corresponding linear index (sequence number in other words)
|
||||
|
@ -7836,15 +7845,15 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongPointer shape, @Const IntPointer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") LongBuffer shape, @Const IntBuffer coords);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(int rank, @Cast("const Nd4jLong*") long[] shape, @Const int[] coords);
|
||||
/**
|
||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] coords, int dimsSize, @Const int[] tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Const IntPointer coords, int dimsSize, @Const IntPointer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Const IntBuffer coords, int dimsSize, @Const IntBuffer tadDims);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long coords2index(@Cast("const Nd4jLong*") long[] shapeInfo, @Const int[] coords, int dimsSize, @Const int[] tadDims);
|
||||
|
||||
/**
|
||||
* increment n-dimensional array by one iteration by changing coord appropriately
|
||||
|
@ -7934,32 +7943,32 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
// function calculates the coordinates of min array (and saves them into minIdxs) given coordinates of max array (already stored in maxIdxs)
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
// dimsLen - length of dimsToExclude, if not set (= -1), then it is calculated as maxRank - minRank
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("Nd4jLong*") LongPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("Nd4jLong*") LongBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("Nd4jLong*") long[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntPointer maxIdxs, IntPointer minIdxs, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(IntBuffer maxIdxs, IntBuffer minIdxs, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/, int dimsLen/*=-1*/);
|
||||
@Namespace("shape") public static native void maxIndToMinInd(int[] maxIdxs, int[] minIdxs, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
|
||||
// calculate indexes of max-array, these output indexes correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") LongBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(@Cast("Nd4jLong*") long[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntPointer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(IntBuffer maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayIndexes(int[] maxIdxs, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
||||
|
||||
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, IntPointer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, IntBuffer memBuff);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, int[] memBuff);
|
||||
|
||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||
// rank is equal to size of shape
|
||||
|
@ -8055,9 +8064,7 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
* get stride over contiguous axis (contiguous axis must have stride = 1)
|
||||
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
|
||||
*/
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongPointer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") LongBuffer inShapeInfo);
|
||||
@Namespace("shape") public static native @Cast("Nd4jLong") long strideOverContigAxis(int axis, @Cast("const Nd4jLong*") long[] inShapeInfo);
|
||||
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
|
||||
|
||||
|
||||
|
||||
|
@ -8933,6 +8940,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
|
||||
|
||||
|
@ -9129,6 +9140,23 @@ public static final int PREALLOC_SIZE = 33554432;
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
|
||||
|
||||
// Nd4jLong result = 9223372036854775807LL;
|
||||
|
||||
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
|
||||
|
||||
// const auto currentStride = shape::stride(inShapeInfo)[i];
|
||||
|
||||
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
|
||||
// continue;
|
||||
|
||||
// if(result > currentStride)
|
||||
// result = currentStride;
|
||||
// }
|
||||
|
||||
// return result == 9223372036854775807LL ? 1 : result;
|
||||
// }
|
||||
|
||||
|
||||
|
||||
|
@ -11949,18 +11977,17 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs);
|
||||
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @Cast("Nd4jLong*") @StdVector LongBuffer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector IntBuffer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet evaluate(@Const @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs/*=std::vector<bool>()*/, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/);
|
||||
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs, @Cast("sd::DataType*") @StdVector IntPointer dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector DoublePointer tArgs, @Cast("Nd4jLong*") @StdVector LongPointer iArgs, @Cast("bool*") @StdVector BooleanPointer bArgs);
|
||||
|
@ -11975,8 +12002,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs, @Cast("sd::DataType*") @StdVector int[] dArgs/*=std::vector<sd::DataType>()*/, @Cast("bool") boolean isInplace/*=false*/, @Cast("sd::DataType") int type/*=sd::DataType::FLOAT32*/);
|
||||
public native @Cast("Nd4jStatus") int execute(@ByRef RandomGenerator rng, @Const @ByRef NDArrayVector inputs, @Const @ByRef NDArrayVector outputs, @StdVector double[] tArgs, @Cast("Nd4jLong*") @StdVector long[] iArgs, @Cast("bool*") @StdVector boolean[] bArgs);
|
||||
|
||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder, @Cast("bool") boolean isInplace/*=false*/);
|
||||
public native @ByVal ResultSet execute(@Const @ByRef OpArgsHolder holder);
|
||||
|
||||
|
||||
// There methods provide various validation options
|
||||
public native @Cast("Nd4jStatus") int validateNonEmptyInput(@ByRef Context block);
|
||||
|
@ -12048,9 +12076,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
|
||||
|
||||
public native @Cast("Nd4jStatus") int execute(Context block);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||
public native ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoublePointer tArgs, @StdVector IntPointer iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector DoubleBuffer tArgs, @StdVector IntBuffer iArgs);
|
||||
public native @ByVal ResultSet execute(NDArrayList list, @ByRef NDArrayVector inputs, @StdVector double[] tArgs, @StdVector int[] iArgs);
|
||||
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
|
|
|
@ -2050,7 +2050,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
print(out[0]);
|
||||
*/
|
||||
|
||||
INDArray emptyIn = Nd4j.empty(DataType.FLOAT);
|
||||
INDArray emptyIn = Nd4j.empty(DataType.FLOAT).reshape(0, 4);
|
||||
INDArray axis = Nd4j.scalar(1);
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("split")
|
||||
|
@ -2061,9 +2061,10 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
List<LongShapeDescriptor> l = op.calculateOutputShape();
|
||||
assertEquals(4, l.size());
|
||||
for( int i=0; i<4; i++ ){
|
||||
assertArrayEquals(new long[0], l.get(i).getShape());
|
||||
assertTrue(l.get(i).isEmpty());
|
||||
op.addOutputArgument(Nd4j.empty(DataType.FLOAT));
|
||||
val desc = l.get(i);
|
||||
assertArrayEquals(new long[]{0, 1}, desc.getShape());
|
||||
assertTrue(desc.isEmpty());
|
||||
op.addOutputArgument(Nd4j.empty(DataType.FLOAT).reshape(desc.getShape()));
|
||||
}
|
||||
|
||||
Nd4j.exec(op);
|
||||
|
|
Loading…
Reference in New Issue