Shyrma scatter (#84)

* - improve performance of scatter (no lock) ops for 1D case

Signed-off-by: Yurii <iuriish@yahoo.com>

* - improve scatter lock op performance for 1D case

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add kernel for verification of input indices-array elements in scatter and scatter_nd ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide fast indices checking on cpu side for scatter and gather osp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - apply corrections requested by pr reviewer

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2019-11-26 19:29:09 +02:00 committed by raver119
parent 8843c7377a
commit a8dd6713aa
24 changed files with 1180 additions and 576 deletions

View File

@ -533,7 +533,7 @@ namespace shape {
* the given shape info buffer * the given shape info buffer
* represents a scalar shape * represents a scalar shape
*/ */
ND4J_EXPORT _CUDA_HD int isScalar(Nd4jLong *info); ND4J_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info);
/** /**
* Returns whether * Returns whether
@ -904,6 +904,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords); ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords); ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords); ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords);
/** /**
* take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order! * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
*/ */
@ -2706,7 +2707,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
* the given shape info buffer * the given shape info buffer
* represents a scalar shape * represents a scalar shape
*/ */
INLINEDEF _CUDA_HD int isScalar(Nd4jLong *info) { INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) {
const int rank = shape::rank(info); const int rank = shape::rank(info);
@ -2715,9 +2716,9 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
if(rank == 0) if(rank == 0)
return 1; return 1;
if(rank == 1) if(rank == 1)
return shape::shapeOf(info)[0] == 1; return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1;
if(rank == 2) if(rank == 2)
return shape::shapeOf(info)[0] == 1 && shape::shapeOf(info)[1] == 1; return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1 && shape::shapeOf(const_cast<Nd4jLong*>(info))[1] == 1;
return 0; return 0;
} }
@ -4793,6 +4794,16 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
coords[0] = index; // last iteration coords[0] = index; // last iteration
} }
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords) {
for(uint i = rank - 1; i > 0; --i) {
coords[i] = index % shape[i];
index /= shape[i];
}
coords[0] = index; // last iteration
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) { INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {

View File

@ -39,6 +39,7 @@ OP_IMPL(scatter_add, 3, 1, true) {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -71,8 +72,15 @@ OP_IMPL(scatter_add, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -39,6 +39,7 @@ namespace nd4j {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -70,9 +71,15 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
// ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -39,6 +39,7 @@ OP_IMPL(scatter_max, 3, 1, true) {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -70,8 +71,15 @@ OP_IMPL(scatter_max, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MAX OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -39,6 +39,7 @@ OP_IMPL(scatter_min, 3, 1, true) {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -70,8 +71,15 @@ OP_IMPL(scatter_min, 3, 1, true) {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MIN OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -35,6 +35,7 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -70,8 +71,15 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MUL OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -35,6 +35,7 @@ namespace ops {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
const int updRank = updates->rankOf(); const int updRank = updates->rankOf();
@ -53,6 +54,11 @@ namespace ops {
std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape)); std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape));
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
// initial zeroing of output // initial zeroing of output
*output = 0; *output = 0;

View File

@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_add, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_add, 3, 1, true) {
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
if (!block.isInplace()) if (!block.isInplace())
output->assign(input); output->assign(input);

View File

@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) {
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
if (!block.isInplace()) if (!block.isInplace())
output->assign(input); output->assign(input);

View File

@ -34,7 +34,8 @@ OP_IMPL(scatter_nd_update, 3, 1, true) {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
bool lock = block.getBArguments()->empty() ? true : B_ARG(0); const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -53,6 +54,11 @@ OP_IMPL(scatter_nd_update, 3, 1, true) {
std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape));
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_UPDATE OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
if (!block.isInplace()) if (!block.isInplace())
output->assign(input); output->assign(input);

View File

@ -38,6 +38,7 @@ namespace nd4j {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -70,9 +71,16 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
// ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates); // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -37,6 +37,7 @@ namespace nd4j {
output->assign(input); output->assign(input);
const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1);
const int inRank = input->rankOf(); const int inRank = input->rankOf();
const int indRank = indices->rankOf(); const int indRank = indices->rankOf();
@ -68,9 +69,16 @@ namespace nd4j {
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
} }
if (!indices->isEmpty()) if (!indices->isEmpty()) {
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_UPD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
// ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates); // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
}
return Status::OK(); return Status::OK();
} }

View File

@ -22,7 +22,8 @@
#if NOT_EXCLUDED(OP_gather) #if NOT_EXCLUDED(OP_gather)
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/gather.h> #include <ops/declarable/helpers/gather.h>
#include <ops/declarable/helpers/scatter.h>
namespace nd4j { namespace nd4j {
@ -36,6 +37,8 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) {
auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr;
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0);
//Edge case: empty indices -> empty output //Edge case: empty indices -> empty output
if(indices != nullptr && indices->isEmpty()){ if(indices != nullptr && indices->isEmpty()){
REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty"); REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty");
@ -64,13 +67,15 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) {
REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank); REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank);
REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !");
if (indices != nullptr) { if(checkIndices) {
for(int i = 0; i < indices->lengthOf(); ++i)
REQUIRE_TRUE(indices->e<Nd4jLong>(i) < input->sizeAt(intArgs[0]), 0, "GATHER op: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !"); NDArray* pIndices = indices;
} if(indices == nullptr)
else { pIndices = new NDArray(input->ordering(), {static_cast<int>(intArgs.size()) - 1}, std::vector<double>(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext());
for(int i = 1; i < numOfIntArgs; ++i) const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *pIndices, *input, intArgs[0]);
REQUIRE_TRUE(intArgs[i] < input->sizeAt(intArgs[0]), 0, "GATHER op: some of indexes is larger than corresponding shape of input array !"); REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
if(indices == nullptr)
delete pIndices;
} }
helpers::gather(block.launchContext(), input, indices, output, intArgs); helpers::gather(block.launchContext(), input, indices, output, intArgs);

View File

@ -23,6 +23,7 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/transforms.h> #include<ops/declarable/helpers/transforms.h>
#include <ops/declarable/helpers/scatter.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -35,6 +36,8 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) {
auto indices = INPUT_VARIABLE(1); auto indices = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0);
const int rankIn = input->rankOf(); const int rankIn = input->rankOf();
const int rankInd = indices->rankOf(); const int rankInd = indices->rankOf();
@ -42,6 +45,11 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) {
int lastIndDim = indices->sizeAt(-1); int lastIndDim = indices->sizeAt(-1);
REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn);
if(checkIndices) {
const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input);
REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
}
helpers::gatherND(block.launchContext(), *input, *indices, *output); helpers::gatherND(block.launchContext(), *input, *indices, *output);
return Status::OK(); return Status::OK();

View File

@ -27,6 +27,49 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
///////////////////////////////////////////////////////////////////
// x - indices, z - input/output
template<typename T>
Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int axis) {
std::atomic<int64_t> numOfBadIndx{0};
const auto x = indices.bufferAsT<T>();
const auto xShapeInfo = indices.getShapeInfo();
const auto zShapeInfo = output.getShapeInfo();
const auto xRank = indices.rankOf();
auto func = PRAGMA_THREADS_FOR {
Nd4jLong xCoords[MAX_RANK];
for (auto i = start; i < stop; i += increment) {
shape::index2coords(i, xShapeInfo, xCoords);
const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)];
if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? xCoords[xRank-1] : axis)) {
printf("checkIndices: out of range element %lld at index %ld \n", currentInd, i);
++numOfBadIndx;
}
}
};
samediff::Threads::parallel_for(func, 0, indices.lengthOf());
return numOfBadIndx;
}
///////////////////////////////////////////////////////////////////
Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) {
BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, (indices, output, axis), INDEXING_TYPES);
}
///////////////////////////////////////////////////////////////////
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
const int outRank = output.rankOf(); const int outRank = output.rankOf();

View File

@ -108,12 +108,12 @@ __host__ static void gatherCudaLauncher(const cudaStream_t *stream, const int nu
void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) { void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
const int inputRank = input->rankOf(); const int inputRank = input->rankOf();
int axis = intArgs.size() > 0 ? intArgs[0] : 0; const int numOfIntArgs = intArgs.size();
int axis = numOfIntArgs > 0 ? intArgs[0] : 0;
if(axis < 0) if(axis < 0)
axis += inputRank; axis += inputRank;
const int numOfIntArgs = intArgs.size();
if (indices == nullptr && numOfIntArgs == 2) { // scalar case if (indices == nullptr && numOfIntArgs == 2) { // scalar case
output->assign((*input)(intArgs[1], {axis})); output->assign((*input)(intArgs[1], {axis}));
} }

View File

@ -106,7 +106,7 @@ namespace nd4j {
const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart);
z[zOffset] = x[xOffset]; z[zOffset] = x[xOffset];
printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -31,6 +31,8 @@ namespace nd4j {
void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad); void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad);
Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis = -1);
} }
} }
} }

View File

@ -76,6 +76,16 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests16, scatter_upd_3) {
NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32);
NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32);
NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32);
NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32);
nd4j::ops::scatter_upd op;
ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true}));
}
TEST_F(DeclarableOpsTests16, test_size_dtype_1) { TEST_F(DeclarableOpsTests16, test_size_dtype_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1}); auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});

View File

@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests2, gather_2) {
nd4j::ops::gather op; nd4j::ops::gather op;
auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}); auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -125,7 +125,7 @@ TEST_F(DeclarableOpsTests2, gather_5) {
nd4j::ops::gather op; nd4j::ops::gather op;
auto result = op.execute({&input, &indices}, {}, {1}); auto result = op.execute({&input, &indices}, {}, {1}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests2, gather_13) {
nd4j::ops::gather op; nd4j::ops::gather op;
auto result = op.execute({&input, &indices}, {}, {2}); auto result = op.execute({&input, &indices}, {}, {2}, {true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -306,6 +306,30 @@ TEST_F(DeclarableOpsTests2, gather_13) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, gather_14) {
NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, nd4j::DataType::INT32);
NDArray output('c', {2,2,3,4});
nd4j::ops::gather op;
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true}));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, gather_15) {
NDArray input ('c', {2,3,4,5}, nd4j::DataType::DOUBLE);
NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, nd4j::DataType::INT32);
NDArray output('c', {2,3, 2,3,4, 5});
nd4j::ops::gather op;
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true}));
}
TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) {
NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32); NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32);

View File

@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test2) {
auto expected = NDArrayFactory::create<double>('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4}); auto expected = NDArrayFactory::create<double>('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4});
nd4j::ops::gather_nd op; nd4j::ops::gather_nd op;
auto results = op.execute({&input, &indices}, {}, {}); auto results = op.execute({&input, &indices}, {}, {}, {true});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -798,7 +798,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test7) {
auto expected = NDArrayFactory::create<double>('c', {3,3}, {3,5,5,8,5,10,2,2,14}); auto expected = NDArrayFactory::create<double>('c', {3,3}, {3,5,5,8,5,10,2,2,14});
nd4j::ops::gather_nd op; nd4j::ops::gather_nd op;
auto results = op.execute({&input, &indices}, {}, {}); auto results = op.execute({&input, &indices}, {}, {}, {true});
auto output = results->at(0); auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
@ -825,6 +825,52 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests5, gatherNd_test9) {
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
auto indices = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
auto exp = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});
x.linspace(1);
nd4j::ops::gather_nd op;
auto result = op.execute({&x, &indices}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
//z->printIndexedBuffer();
//z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, gatherNd_test10) {
auto input = NDArrayFactory::create<double>('c', {4, 3, 2});
auto indices = NDArrayFactory::create<int>('c', {2,2,2}, {30,20,1,2, 0,10,0,1});
auto output = NDArrayFactory::create<double>('c', {2,2,2});
nd4j::ops::gather_nd op;
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true}));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, gatherNd_test11) {
auto input = NDArrayFactory::create<double>('c', {4, 4});
auto indices = NDArrayFactory::create<int>('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1});
auto output = NDArrayFactory::create<double>('c', {3,3});
nd4j::ops::gather_nd op;
ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true}));
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {

View File

@ -314,27 +314,6 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) {
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
auto indices = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
auto exp = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});
x.linspace(1);
nd4j::ops::gather_nd op;
auto result = op.execute({&x, &indices}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
//z->printIndexedBuffer();
//z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Order_1) { TEST_F(DeclarableOpsTests6, Test_Order_1) {
auto x = NDArrayFactory::create<double>('f', {2, 3}); auto x = NDArrayFactory::create<double>('f', {2, 3});
auto exp = NDArrayFactory::create<double>('c', {2, 3}); auto exp = NDArrayFactory::create<double>('c', {2, 3});

View File

@ -756,7 +756,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -791,7 +791,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13});
nd4j::ops::scatter_add op; nd4j::ops::scatter_add op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -837,6 +837,18 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_8) {
ASSERT_TRUE(expected.equalsTo(z)); ASSERT_TRUE(expected.equalsTo(z));
} }
////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, Test_Scatter_Add_9) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
NDArray idc('c', {2, 2}, {1, 10, 0, 0}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto output = NDArrayFactory::create<float>('c', {2, 2, 3});
nd4j::ops::scatter_add op;
ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true}));
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterMax_test1) { TEST_F(ParityOpsTests, scatterMax_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
@ -1010,6 +1022,18 @@ TEST_F(ParityOpsTests, scatterMin_test4) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterMin_test5) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1,2}, {10,10}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
auto output = NDArrayFactory::create<float>('c', {2, 2, 2});
nd4j::ops::scatter_min op;
ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true}));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_test1) { TEST_F(ParityOpsTests, scatterND_test1) {
@ -1019,7 +1043,7 @@ TEST_F(ParityOpsTests, scatterND_test1) {
auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); auto exp = NDArrayFactory::create<float>('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}); auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1066,7 +1090,7 @@ TEST_F(ParityOpsTests, scatterND_test3) {
updates.linspace(1.f); updates.linspace(1.f);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}); auto result = op.execute({&indices, &updates, &shape}, {}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1160,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test7) {
updates.linspace(1); updates.linspace(1);
nd4j::ops::scatter_nd op; nd4j::ops::scatter_nd op;
auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true}); auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true});
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
@ -1193,6 +1217,20 @@ TEST_F(ParityOpsTests, scatterND_test8) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_test9) {
NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {2,3, 3,4});
auto shape = NDArrayFactory::create<int>('c', {3}, {10, 3, 4});
auto output = NDArrayFactory::create<float>('c', {10, 3, 4});
nd4j::ops::scatter_nd op;
ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true}));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_add_test1) { TEST_F(ParityOpsTests, scatterND_add_test1) {
@ -1323,6 +1361,19 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_add_test6) {
auto input = NDArrayFactory::create<float>('c', {6, 4});
NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {2,3,4});
auto output = NDArrayFactory::create<float>('c', {6,4});
nd4j::ops::scatter_nd_add op;
ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true}));
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_sub_test1) { TEST_F(ParityOpsTests, scatterND_sub_test1) {
@ -1586,6 +1637,19 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterND_update_test6) {
auto input = NDArrayFactory::create<float>('c', {6, 4});
NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {3,3});
auto output = NDArrayFactory::create<float>('c', {6,4});
nd4j::ops::scatter_nd_update op;
ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true}));
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatter_update_1) { TEST_F(ParityOpsTests, scatter_update_1) {