Refactoring split operation (#266)
* libnd4j moved split operation implementation to helpers before special case adding Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j minor fixes for general split operation move, merge master Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libndj4 split cpu implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - provide cuda helper for split op Signed-off-by: Yurii <iuriish@yahoo.com> * - minor correction Signed-off-by: Yurii <iuriish@yahoo.com> * - minor correction 2 Signed-off-by: Yurii <iuriish@yahoo.com> * libnd4j moved split implementation from specials to split.cpp Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j update loopkind selections for 3D, 4D and 5D cases Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j removed unnecessary BUILD_SINGLE_TEMPLATE Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>master
parent
cf67c7165a
commit
b686368b82
|
@ -107,12 +107,22 @@ LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, c
|
||||||
|
|
||||||
bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
|
bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
|
||||||
|
|
||||||
if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
|
||||||
return nd4j::LoopKind::BROADCAST_3D;
|
if (bNDLoopsRanks && bNotCommonVectorCase) {
|
||||||
if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
// case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = z[3,4,5]
|
||||||
return nd4j::LoopKind::BROADCAST_4D;
|
if (nd4j::LoopKind::EWS1 == deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo)
|
||||||
if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
|
&& (1 == shape::sizeAt(yShapeInfo, 0) || 1 == shape::sizeAt(xShapeInfo, 0))) {
|
||||||
return nd4j::LoopKind::BROADCAST_5D;
|
return EWS1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (3 == xRank)
|
||||||
|
return nd4j::LoopKind::BROADCAST_3D;
|
||||||
|
if (4 == xRank)
|
||||||
|
return nd4j::LoopKind::BROADCAST_4D;
|
||||||
|
if (5 == xRank)
|
||||||
|
return nd4j::LoopKind::BROADCAST_5D;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
|
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
|
||||||
|
|
|
@ -18,26 +18,109 @@
|
||||||
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
#include <ops/declarable/helpers/transforms.h>
|
#include <ops/declarable/helpers/transforms.h>
|
||||||
#include <ops/specials.h>
|
#include <helpers/Loops.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
static void split_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
|
|
||||||
nd4j::SpecialMethods<T>::splitCpuGeneric(input, outArrs, axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES);
|
template <typename T>
|
||||||
}
|
static void split_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
|
||||||
|
int numSplits = outArrs.size();
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void split_, (const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis), LIBND4J_TYPES);
|
const auto sizeofT = input.sizeOfT();
|
||||||
|
|
||||||
|
T* xBuff = input.bufferAsT<T>();
|
||||||
|
|
||||||
|
bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
|
||||||
|
|
||||||
|
if (luckCase1) {
|
||||||
|
for (uint i = 0; i < numSplits; ++i) {
|
||||||
|
luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
|
||||||
|
if (!luckCase1)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (luckCase1) {
|
||||||
|
|
||||||
|
T* x = const_cast<T*>(xBuff);
|
||||||
|
for (uint i = 0; i < numSplits; ++i) {
|
||||||
|
const auto memAmountToCopy = outArrs[i]->lengthOf();
|
||||||
|
memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
|
||||||
|
x += memAmountToCopy;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c';
|
||||||
|
bool areOutsContin = true;
|
||||||
|
bool allSameOrder = true;
|
||||||
|
|
||||||
|
if (isXcontin) {
|
||||||
|
for (uint i = 0; i < numSplits; ++i) {
|
||||||
|
areOutsContin &= outArrs[i]->strideAt(axis) == 1;
|
||||||
|
allSameOrder &= outArrs[i]->ordering() == input.ordering();
|
||||||
|
if (!areOutsContin || !allSameOrder)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
|
||||||
|
|
||||||
|
if (luckCase2) {
|
||||||
|
|
||||||
|
const uint xDim = input.sizeAt(axis);
|
||||||
|
|
||||||
|
for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
|
||||||
|
|
||||||
|
T* x = xBuff + xDim * i;
|
||||||
|
|
||||||
|
for (uint j = 0; j < numSplits; ++j) {
|
||||||
|
const auto zDim = outArrs[j]->sizeAt(axis);
|
||||||
|
T* z = outArrs[j]->bufferAsT<T>() + zDim * i;
|
||||||
|
memcpy(z, x, zDim * sizeofT);
|
||||||
|
z += zDim;
|
||||||
|
x += zDim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint zDim = outArrs[0]->sizeAt(axis);
|
||||||
|
// general case
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR{
|
||||||
|
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
|
||||||
|
shape::index2coords(i, input.getShapeInfo(), coords);
|
||||||
|
const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
|
||||||
|
|
||||||
|
uint outArrIdx = 0;
|
||||||
|
|
||||||
|
while (coords[axis] >= zDim) {
|
||||||
|
coords[axis] -= zDim;
|
||||||
|
++outArrIdx;
|
||||||
|
}
|
||||||
|
|
||||||
|
T* z = outArrs[outArrIdx]->bufferAsT<T>();
|
||||||
|
const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
|
||||||
|
z[zOffset] = xBuff[xOffset];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, input.lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
|
@ -217,100 +217,6 @@ void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint
|
||||||
delete inputs[i];
|
delete inputs[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void SpecialMethods<T>::splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
|
|
||||||
|
|
||||||
int numSplits = outArrs.size();
|
|
||||||
|
|
||||||
const auto sizeofT = input.sizeOfT();
|
|
||||||
|
|
||||||
T* xBuff = input.bufferAsT<T>();
|
|
||||||
|
|
||||||
bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
|
|
||||||
|
|
||||||
if (luckCase1) {
|
|
||||||
for (uint i = 0; i < numSplits; ++i) {
|
|
||||||
luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
|
|
||||||
if (!luckCase1)
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (luckCase1) {
|
|
||||||
|
|
||||||
T* x = const_cast<T*>(xBuff);
|
|
||||||
for (uint i = 0; i < numSplits; ++i) {
|
|
||||||
const auto memAmountToCopy = outArrs[i]->lengthOf();
|
|
||||||
memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
|
|
||||||
x += memAmountToCopy;
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c';
|
|
||||||
bool areOutsContin = true;
|
|
||||||
bool allSameOrder = true;
|
|
||||||
|
|
||||||
if (isXcontin) {
|
|
||||||
for (uint i = 0; i < numSplits; ++i) {
|
|
||||||
areOutsContin &= outArrs[i]->strideAt(axis) == 1;
|
|
||||||
allSameOrder &= outArrs[i]->ordering() == input.ordering();
|
|
||||||
if (!areOutsContin || !allSameOrder)
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
|
|
||||||
|
|
||||||
if (luckCase2) {
|
|
||||||
|
|
||||||
const uint xDim = input.sizeAt(axis);
|
|
||||||
|
|
||||||
for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
|
|
||||||
|
|
||||||
T* x = xBuff + xDim * i;
|
|
||||||
|
|
||||||
for (uint j = 0; j < numSplits; ++j) {
|
|
||||||
const auto zDim = outArrs[j]->sizeAt(axis);
|
|
||||||
T* z = outArrs[j]->bufferAsT<T>() + zDim * i;
|
|
||||||
memcpy(z, x, zDim * sizeofT);
|
|
||||||
z += zDim;
|
|
||||||
x += zDim;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint zDim = outArrs[0]->sizeAt(axis);
|
|
||||||
// general case
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR{
|
|
||||||
|
|
||||||
Nd4jLong coords[MAX_RANK];
|
|
||||||
for (auto i = start; i < stop; i += increment) {
|
|
||||||
|
|
||||||
shape::index2coords(i, input.getShapeInfo(), coords);
|
|
||||||
const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
|
|
||||||
|
|
||||||
uint outArrIdx = 0;
|
|
||||||
|
|
||||||
while (coords[axis] >= zDim) {
|
|
||||||
coords[axis] -= zDim;
|
|
||||||
++outArrIdx;
|
|
||||||
}
|
|
||||||
|
|
||||||
T* z = outArrs[outArrIdx]->bufferAsT<T>();
|
|
||||||
const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
|
|
||||||
z[zOffset] = xBuff[xOffset];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
samediff::Threads::parallel_for(func, 0, input.lengthOf());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This kernel accumulates X arrays, and stores result into Z
|
* This kernel accumulates X arrays, and stores result into Z
|
||||||
*
|
*
|
||||||
|
|
|
@ -68,7 +68,6 @@ namespace nd4j {
|
||||||
static void decodeBitmapGeneric(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
|
static void decodeBitmapGeneric(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
|
||||||
static Nd4jLong encodeBitmapGeneric(void *dx, Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold);
|
static Nd4jLong encodeBitmapGeneric(void *dx, Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold);
|
||||||
|
|
||||||
static void splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
|
|
Loading…
Reference in New Issue