Shyrma scatter (#84)
* - improve performance of scatter (no lock) ops for 1D case Signed-off-by: Yurii <iuriish@yahoo.com> * - improve scatter lock op performance for 1D case Signed-off-by: Yurii <iuriish@yahoo.com> * - add kernel for verification of input indices-array elements in scatter and scatter_nd ops Signed-off-by: Yurii <iuriish@yahoo.com> * - provide fast indices checking on cpu side for scatter and gather osp Signed-off-by: Yurii <iuriish@yahoo.com> * - apply corrections requested by pr reviewer Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
8843c7377a
commit
a8dd6713aa
|
@ -533,7 +533,7 @@ namespace shape {
|
||||||
* the given shape info buffer
|
* the given shape info buffer
|
||||||
* represents a scalar shape
|
* represents a scalar shape
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD int isScalar(Nd4jLong *info);
|
ND4J_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns whether
|
* Returns whether
|
||||||
|
@ -904,6 +904,7 @@ namespace shape {
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
|
||||||
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
|
||||||
|
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords);
|
||||||
/**
|
/**
|
||||||
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
|
||||||
*/
|
*/
|
||||||
|
@ -2706,7 +2707,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
* the given shape info buffer
|
* the given shape info buffer
|
||||||
* represents a scalar shape
|
* represents a scalar shape
|
||||||
*/
|
*/
|
||||||
INLINEDEF _CUDA_HD int isScalar(Nd4jLong *info) {
|
INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) {
|
||||||
|
|
||||||
const int rank = shape::rank(info);
|
const int rank = shape::rank(info);
|
||||||
|
|
||||||
|
@ -2715,9 +2716,9 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
if(rank == 0)
|
if(rank == 0)
|
||||||
return 1;
|
return 1;
|
||||||
if(rank == 1)
|
if(rank == 1)
|
||||||
return shape::shapeOf(info)[0] == 1;
|
return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1;
|
||||||
if(rank == 2)
|
if(rank == 2)
|
||||||
return shape::shapeOf(info)[0] == 1 && shape::shapeOf(info)[1] == 1;
|
return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1 && shape::shapeOf(const_cast<Nd4jLong*>(info))[1] == 1;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -4793,6 +4794,16 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
|
||||||
coords[0] = index; // last iteration
|
coords[0] = index; // last iteration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords) {
|
||||||
|
|
||||||
|
for(uint i = rank - 1; i > 0; --i) {
|
||||||
|
coords[i] = index % shape[i];
|
||||||
|
index /= shape[i];
|
||||||
|
}
|
||||||
|
coords[0] = index; // last iteration
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {
|
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ OP_IMPL(scatter_add, 3, 1, true) {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -71,8 +72,15 @@ OP_IMPL(scatter_add, 3, 1, true) {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ namespace nd4j {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -70,9 +71,15 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
// ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ OP_IMPL(scatter_max, 3, 1, true) {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -70,8 +71,15 @@ OP_IMPL(scatter_max, 3, 1, true) {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MAX OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ OP_IMPL(scatter_min, 3, 1, true) {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -70,8 +71,15 @@ OP_IMPL(scatter_min, 3, 1, true) {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MIN OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ namespace nd4j {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -70,8 +71,15 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MUL OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ namespace ops {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
const int updRank = updates->rankOf();
|
const int updRank = updates->rankOf();
|
||||||
|
@ -53,6 +54,11 @@ namespace ops {
|
||||||
std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape));
|
std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape));
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
// initial zeroing of output
|
// initial zeroing of output
|
||||||
*output = 0;
|
*output = 0;
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_add, 3, 1, true) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_add, 3, 1, true) {
|
||||||
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
if (!block.isInplace())
|
if (!block.isInplace())
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) {
|
||||||
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
if (!block.isInplace())
|
if (!block.isInplace())
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_update, 3, 1, true) {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_update, 3, 1, true) {
|
||||||
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_UPDATE OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
if (!block.isInplace())
|
if (!block.isInplace())
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,7 @@ namespace nd4j {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -70,9 +71,16 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
// ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
|
// ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
|
||||||
helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,7 @@ namespace nd4j {
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
|
const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
|
||||||
|
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
const int indRank = indices->rankOf();
|
const int indRank = indices->rankOf();
|
||||||
|
@ -68,9 +69,16 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!indices->isEmpty())
|
if (!indices->isEmpty()) {
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_UPD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
// ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
|
// ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
|
||||||
helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
|
helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,8 @@
|
||||||
#if NOT_EXCLUDED(OP_gather)
|
#if NOT_EXCLUDED(OP_gather)
|
||||||
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include<ops/declarable/helpers/gather.h>
|
#include <ops/declarable/helpers/gather.h>
|
||||||
|
#include <ops/declarable/helpers/scatter.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -36,6 +37,8 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) {
|
||||||
auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr;
|
auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr;
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
|
||||||
//Edge case: empty indices -> empty output
|
//Edge case: empty indices -> empty output
|
||||||
if(indices != nullptr && indices->isEmpty()){
|
if(indices != nullptr && indices->isEmpty()){
|
||||||
REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty");
|
REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty");
|
||||||
|
@ -64,13 +67,15 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) {
|
||||||
REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank);
|
REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank);
|
||||||
REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !");
|
REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !");
|
||||||
|
|
||||||
if (indices != nullptr) {
|
if(checkIndices) {
|
||||||
for(int i = 0; i < indices->lengthOf(); ++i)
|
|
||||||
REQUIRE_TRUE(indices->e<Nd4jLong>(i) < input->sizeAt(intArgs[0]), 0, "GATHER op: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !");
|
NDArray* pIndices = indices;
|
||||||
}
|
if(indices == nullptr)
|
||||||
else {
|
pIndices = new NDArray(input->ordering(), {static_cast<int>(intArgs.size()) - 1}, std::vector<double>(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext());
|
||||||
for(int i = 1; i < numOfIntArgs; ++i)
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *pIndices, *input, intArgs[0]);
|
||||||
REQUIRE_TRUE(intArgs[i] < input->sizeAt(intArgs[0]), 0, "GATHER op: some of indexes is larger than corresponding shape of input array !");
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
if(indices == nullptr)
|
||||||
|
delete pIndices;
|
||||||
}
|
}
|
||||||
|
|
||||||
helpers::gather(block.launchContext(), input, indices, output, intArgs);
|
helpers::gather(block.launchContext(), input, indices, output, intArgs);
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include<ops/declarable/helpers/transforms.h>
|
#include<ops/declarable/helpers/transforms.h>
|
||||||
|
#include <ops/declarable/helpers/scatter.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -35,6 +36,8 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) {
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
|
|
||||||
const int rankIn = input->rankOf();
|
const int rankIn = input->rankOf();
|
||||||
const int rankInd = indices->rankOf();
|
const int rankInd = indices->rankOf();
|
||||||
|
|
||||||
|
@ -42,6 +45,11 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) {
|
||||||
int lastIndDim = indices->sizeAt(-1);
|
int lastIndDim = indices->sizeAt(-1);
|
||||||
REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn);
|
REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn);
|
||||||
|
|
||||||
|
if(checkIndices) {
|
||||||
|
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input);
|
||||||
|
REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
|
||||||
|
}
|
||||||
|
|
||||||
helpers::gatherND(block.launchContext(), *input, *indices, *output);
|
helpers::gatherND(block.launchContext(), *input, *indices, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -27,6 +27,49 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// x - indices, z - input/output
|
||||||
|
template<typename T>
|
||||||
|
Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int axis) {
|
||||||
|
|
||||||
|
std::atomic<int64_t> numOfBadIndx{0};
|
||||||
|
|
||||||
|
const auto x = indices.bufferAsT<T>();
|
||||||
|
|
||||||
|
const auto xShapeInfo = indices.getShapeInfo();
|
||||||
|
const auto zShapeInfo = output.getShapeInfo();
|
||||||
|
|
||||||
|
const auto xRank = indices.rankOf();
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
|
||||||
|
Nd4jLong xCoords[MAX_RANK];
|
||||||
|
|
||||||
|
for (auto i = start; i < stop; i += increment) {
|
||||||
|
|
||||||
|
shape::index2coords(i, xShapeInfo, xCoords);
|
||||||
|
|
||||||
|
const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)];
|
||||||
|
|
||||||
|
if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) {
|
||||||
|
printf("checkIndices: out of range element %lld at index %ld \n", currentInd, i);
|
||||||
|
++numOfBadIndx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, indices.lengthOf());
|
||||||
|
|
||||||
|
return numOfBadIndx;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) {
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, (indices, output, axis), INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
||||||
|
|
||||||
const int outRank = output.rankOf();
|
const int outRank = output.rankOf();
|
||||||
|
|
|
@ -108,12 +108,12 @@ __host__ static void gatherCudaLauncher(const cudaStream_t *stream, const int nu
|
||||||
void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
|
void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
|
||||||
|
|
||||||
const int inputRank = input->rankOf();
|
const int inputRank = input->rankOf();
|
||||||
int axis = intArgs.size() > 0 ? intArgs[0] : 0;
|
const int numOfIntArgs = intArgs.size();
|
||||||
|
|
||||||
|
int axis = numOfIntArgs > 0 ? intArgs[0] : 0;
|
||||||
if(axis < 0)
|
if(axis < 0)
|
||||||
axis += inputRank;
|
axis += inputRank;
|
||||||
|
|
||||||
const int numOfIntArgs = intArgs.size();
|
|
||||||
|
|
||||||
if (indices == nullptr && numOfIntArgs == 2) { // scalar case
|
if (indices == nullptr && numOfIntArgs == 2) { // scalar case
|
||||||
output->assign((*input)(intArgs[1], {axis}));
|
output->assign((*input)(intArgs[1], {axis}));
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,7 +106,7 @@ namespace nd4j {
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart);
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart);
|
||||||
|
|
||||||
z[zOffset] = x[xOffset];
|
z[zOffset] = x[xOffset];
|
||||||
printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]);
|
// printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -31,6 +31,8 @@ namespace nd4j {
|
||||||
void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
||||||
|
|
||||||
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad);
|
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad);
|
||||||
|
|
||||||
|
Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis = -1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,6 +76,16 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32);
|
||||||
|
NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::scatter_upd op;
|
||||||
|
ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
|
TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||||
|
|
|
@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests2, gather_2) {
|
||||||
|
|
||||||
nd4j::ops::gather op;
|
nd4j::ops::gather op;
|
||||||
|
|
||||||
auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2});
|
auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ TEST_F(DeclarableOpsTests2, gather_5) {
|
||||||
|
|
||||||
nd4j::ops::gather op;
|
nd4j::ops::gather op;
|
||||||
|
|
||||||
auto result = op.execute({&input, &indices}, {}, {1});
|
auto result = op.execute({&input, &indices}, {}, {1}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests2, gather_13) {
|
||||||
|
|
||||||
nd4j::ops::gather op;
|
nd4j::ops::gather op;
|
||||||
|
|
||||||
auto result = op.execute({&input, &indices}, {}, {2});
|
auto result = op.execute({&input, &indices}, {}, {2}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -306,6 +306,30 @@ TEST_F(DeclarableOpsTests2, gather_13) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests2, gather_14) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
|
||||||
|
NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, nd4j::DataType::INT32);
|
||||||
|
NDArray output('c', {2,2,3,4});
|
||||||
|
|
||||||
|
nd4j::ops::gather op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true}));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests2, gather_15) {
|
||||||
|
|
||||||
|
NDArray input ('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, nd4j::DataType::INT32);
|
||||||
|
NDArray output('c', {2,3, 2,3,4, 5});
|
||||||
|
|
||||||
|
nd4j::ops::gather op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) {
|
TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) {
|
||||||
|
|
||||||
NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32);
|
NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32);
|
||||||
|
|
|
@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test2) {
|
||||||
auto expected = NDArrayFactory::create<double>('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4});
|
auto expected = NDArrayFactory::create<double>('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4});
|
||||||
|
|
||||||
nd4j::ops::gather_nd op;
|
nd4j::ops::gather_nd op;
|
||||||
auto results = op.execute({&input, &indices}, {}, {});
|
auto results = op.execute({&input, &indices}, {}, {}, {true});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -798,7 +798,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test7) {
|
||||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {3,5,5,8,5,10,2,2,14});
|
auto expected = NDArrayFactory::create<double>('c', {3,3}, {3,5,5,8,5,10,2,2,14});
|
||||||
|
|
||||||
nd4j::ops::gather_nd op;
|
nd4j::ops::gather_nd op;
|
||||||
auto results = op.execute({&input, &indices}, {}, {});
|
auto results = op.execute({&input, &indices}, {}, {}, {true});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
@ -825,6 +825,52 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests5, gatherNd_test9) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});
|
||||||
|
x.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::gather_nd op;
|
||||||
|
auto result = op.execute({&x, &indices}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
//z->printIndexedBuffer();
|
||||||
|
//z->printShapeInfo("z shape");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, gatherNd_test10) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {4, 3, 2});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {2,2,2}, {30,20,1,2, 0,10,0,1});
|
||||||
|
|
||||||
|
auto output = NDArrayFactory::create<double>('c', {2,2,2});
|
||||||
|
|
||||||
|
nd4j::ops::gather_nd op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true}));
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, gatherNd_test11) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {4, 4});
|
||||||
|
auto indices = NDArrayFactory::create<int>('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1});
|
||||||
|
auto output = NDArrayFactory::create<double>('c', {3,3});
|
||||||
|
|
||||||
|
nd4j::ops::gather_nd op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true}));
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
|
TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
|
||||||
|
|
||||||
|
|
|
@ -314,27 +314,6 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) {
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
|
|
||||||
auto indices = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});
|
|
||||||
x.linspace(1);
|
|
||||||
|
|
||||||
nd4j::ops::gather_nd op;
|
|
||||||
auto result = op.execute({&x, &indices}, {}, {});
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
|
||||||
|
|
||||||
auto z = result->at(0);
|
|
||||||
|
|
||||||
//z->printIndexedBuffer();
|
|
||||||
//z->printShapeInfo("z shape");
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests6, Test_Order_1) {
|
TEST_F(DeclarableOpsTests6, Test_Order_1) {
|
||||||
auto x = NDArrayFactory::create<double>('f', {2, 3});
|
auto x = NDArrayFactory::create<double>('f', {2, 3});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 3});
|
auto exp = NDArrayFactory::create<double>('c', {2, 3});
|
||||||
|
|
|
@ -756,7 +756,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
|
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
|
||||||
|
|
||||||
nd4j::ops::scatter_add op;
|
nd4j::ops::scatter_add op;
|
||||||
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -791,7 +791,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
|
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
|
||||||
|
|
||||||
nd4j::ops::scatter_add op;
|
nd4j::ops::scatter_add op;
|
||||||
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true});
|
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -837,6 +837,18 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_8) {
|
||||||
ASSERT_TRUE(expected.equalsTo(z));
|
ASSERT_TRUE(expected.equalsTo(z));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, Test_Scatter_Add_9) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||||
|
NDArray idc('c', {2, 2}, {1, 10, 0, 0}, nd4j::DataType::INT64);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
|
auto output = NDArrayFactory::create<float>('c', {2, 2, 3});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_add op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true}));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatterMax_test1) {
|
TEST_F(ParityOpsTests, scatterMax_test1) {
|
||||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
|
@ -1010,6 +1022,18 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatterMin_test5) {
|
||||||
|
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||||
|
NDArray idc('c', {1,2}, {10,10}, nd4j::DataType::INT32);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
|
||||||
|
auto output = NDArrayFactory::create<float>('c', {2, 2, 2});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_min op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true}));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatterND_test1) {
|
TEST_F(ParityOpsTests, scatterND_test1) {
|
||||||
|
|
||||||
|
@ -1019,7 +1043,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::scatter_nd op;
|
nd4j::ops::scatter_nd op;
|
||||||
auto result = op.execute({&indices, &updates, &shape}, {}, {});
|
auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1066,7 +1090,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
|
||||||
updates.linspace(1.f);
|
updates.linspace(1.f);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd op;
|
nd4j::ops::scatter_nd op;
|
||||||
auto result = op.execute({&indices, &updates, &shape}, {}, {});
|
auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1160,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
|
||||||
updates.linspace(1);
|
updates.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::scatter_nd op;
|
nd4j::ops::scatter_nd op;
|
||||||
auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true});
|
auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
@ -1193,6 +1217,20 @@ TEST_F(ParityOpsTests, scatterND_test8) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatterND_test9) {
|
||||||
|
|
||||||
|
NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, nd4j::DataType::INT32);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {2,3, 3,4});
|
||||||
|
auto shape = NDArrayFactory::create<int>('c', {3}, {10, 3, 4});
|
||||||
|
auto output = NDArrayFactory::create<float>('c', {10, 3, 4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_nd op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true}));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatterND_add_test1) {
|
TEST_F(ParityOpsTests, scatterND_add_test1) {
|
||||||
|
|
||||||
|
@ -1323,6 +1361,19 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatterND_add_test6) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {6, 4});
|
||||||
|
NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, nd4j::DataType::INT32);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {2,3,4});
|
||||||
|
auto output = NDArrayFactory::create<float>('c', {6,4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_nd_add op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true}));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatterND_sub_test1) {
|
TEST_F(ParityOpsTests, scatterND_sub_test1) {
|
||||||
|
|
||||||
|
@ -1586,6 +1637,19 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatterND_update_test6) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<float>('c', {6, 4});
|
||||||
|
NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, nd4j::DataType::INT32);
|
||||||
|
auto updates = NDArrayFactory::create<float>('c', {3,3});
|
||||||
|
auto output = NDArrayFactory::create<float>('c', {6,4});
|
||||||
|
|
||||||
|
nd4j::ops::scatter_nd_update op;
|
||||||
|
|
||||||
|
ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true}));
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatter_update_1) {
|
TEST_F(ParityOpsTests, scatter_update_1) {
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue