diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h
index 0a581d718..12162d77c 100644
--- a/libnd4j/include/helpers/shape.h
+++ b/libnd4j/include/helpers/shape.h
@@ -533,7 +533,7 @@ namespace shape {
      * the given shape info buffer
      * represents a scalar shape
      */
-    ND4J_EXPORT _CUDA_HD int isScalar(Nd4jLong *info);
+    ND4J_EXPORT _CUDA_HD int isScalar(const Nd4jLong *info);

     /**
      * Returns whether
@@ -904,6 +904,7 @@ namespace shape {
     ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, int *coords);
     ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, uint *coords);
     ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
+    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords);

     /**
      * take into account only dimensions stored in tadDims, tadDims must be sorted in increasing order!
      */
@@ -2706,7 +2707,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
  * the given shape info buffer
  * represents a scalar shape
  */
-    INLINEDEF _CUDA_HD int isScalar(Nd4jLong *info) {
+    INLINEDEF _CUDA_HD int isScalar(const Nd4jLong *info) {

         const int rank = shape::rank(info);

@@ -2715,9 +2716,9 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
         if(rank == 0)
             return 1;
         if(rank == 1)
-            return shape::shapeOf(info)[0] == 1;
+            return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1;
         if(rank == 2)
-            return shape::shapeOf(info)[0] == 1 && shape::shapeOf(info)[1] == 1;
+            return shape::shapeOf(const_cast<Nd4jLong*>(info))[0] == 1 && shape::shapeOf(const_cast<Nd4jLong*>(info))[1] == 1;

         return 0;
     }

@@ -4793,6 +4794,16 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jL
     coords[0] = index;      // last iteration
 }

+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, int *coords) {
+
+    for(uint i = rank - 1; i > 0; --i) {
+        coords[i] = index % shape[i];
+        index /= shape[i];
+    }
+    coords[0] = index;      // last iteration
+}
+
 //////////////////////////////////////////////////////////////////////
 INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords, const int dimsSize, const int* tadDims) {

diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
index 0d9465b02..db4eeeff6 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
@@ -39,6 +39,7 @@ OP_IMPL(scatter_add, 3, 1, true) {
         output->assign(input);

     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
+    const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -71,8 +72,15 @@ OP_IMPL(scatter_add, 3, 1, true) { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp index dccc34e59..40ddbd424 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp @@ -39,6 +39,7 @@ namespace nd4j { output->assign(input); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -70,9 +71,15 @@ namespace nd4j { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) - // ScatterHelper::template scatterApply>(output, indices, updates); + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_DIV OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp index 5d37a71d0..4ec55f088 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp @@ -39,6 +39,7 @@ OP_IMPL(scatter_max, 3, 1, true) { output->assign(input); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -70,8 +71,15 @@ OP_IMPL(scatter_max, 3, 1, true) { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MAX OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp index 1bed296f9..ea8dbf081 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp @@ -39,6 +39,7 @@ OP_IMPL(scatter_min, 3, 1, true) { output->assign(input); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -70,8 +71,15 @@ OP_IMPL(scatter_min, 3, 1, true) { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MIN OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp index 46b9f7008..4685cef5d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp @@ -35,6 +35,7 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -70,8 +71,15 @@ namespace nd4j { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_MUL OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp index de4839d3e..4c3884e07 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd.cpp @@ -35,6 +35,7 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); @@ -53,6 +54,11 @@ namespace ops { std::move(std::begin(outShape) + indices->sizeAt(-1), std::end(outShape), std::back_inserter(expectedUpdShape)); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + // initial zeroing of output *output = 0; @@ -73,7 +79,7 @@ namespace ops { DECLARE_SHAPE_FN(scatter_nd) { auto shape = INPUT_VARIABLE(2); - auto updShapeInfo = inputShape->at(1); + auto updShapeInfo = inputShape->at(1); Nd4jLong *outShapeInfo; ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(shape->lengthOf()), Nd4jLong); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp index 5679b22ec..43c2c66ed 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_add.cpp @@ -34,25 +34,31 @@ OP_IMPL(scatter_nd_add, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); const Nd4jLong indLastDim = indices->sizeAt(-1); - + REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_ADD OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_ADD OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); std::vector inShape = input->getShapeAsVector(); std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); + std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_ADD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + if (!block.isInplace()) output->assign(input); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp index 7b0d0e1ad..eb4768f86 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_sub.cpp @@ -34,25 +34,31 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); const Nd4jLong indLastDim = indices->sizeAt(-1); - + REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_SUB OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_SUB OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); std::vector inShape = input->getShapeAsVector(); std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); if(inRank > indLastDim) - std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); + std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + if (!block.isInplace()) output->assign(input); @@ -62,7 +68,7 @@ OP_IMPL(scatter_nd_sub, 3, 1, true) { } DECLARE_TYPES(scatter_nd_sub) { - getOpDescriptor() + getOpDescriptor() ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp index e973ff160..e6bfb8703 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_nd_update.cpp @@ -34,25 +34,31 @@ OP_IMPL(scatter_nd_update, 3, 1, true) { auto output = OUTPUT_VARIABLE(0); - bool lock = block.getBArguments()->empty() ? true : B_ARG(0); + const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); const int updRank = updates->rankOf(); const Nd4jLong indLastDim = indices->sizeAt(-1); - + REQUIRE_TRUE(indLastDim <= inRank, 0, "SCATTER_ND_UPDATE OP: the last dimension of indices array must be <= input_array_rank, but got %i instead !", indLastDim); REQUIRE_TRUE(updRank == (indRank - 1 + inRank - indLastDim), 0, "SCATTER_ND_UPDATE OP: the equality updates_rank = (indices_rank - 1 + input_rank - last_indices_dimension) must be true for input arrays, but got instead: updates_rank = %i, indices_rank = %i, last_indices_dimension = %i !", updRank, indRank, indLastDim); std::vector inShape = input->getShapeAsVector(); std::vector updShape = updates->getShapeAsVector(); - std::vector indShape = indices->getShapeAsVector(); - std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); + std::vector indShape = indices->getShapeAsVector(); + std::vector expectedUpdShape(std::begin(indShape), std::end(indShape) - 1); if(inRank > indLastDim) std::move(std::begin(inShape) + indLastDim, std::end(inShape), std::back_inserter(expectedUpdShape)); REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ND_UPDATE OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_ND_UPDATE OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + if (!block.isInplace()) output->assign(input); @@ -62,7 +68,7 @@ OP_IMPL(scatter_nd_update, 3, 1, true) { } DECLARE_TYPES(scatter_nd_update) { - getOpDescriptor() + getOpDescriptor() ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp index cf3745236..1971b346f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp @@ -38,6 +38,7 @@ namespace nd4j { output->assign(input); const bool lock = block.getBArguments()->empty() ? false : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? 
false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -70,9 +71,16 @@ namespace nd4j { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_SUB OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + // ScatterHelper::template scatterApply>(output, indices, updates); helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp index 55076e51e..081d1cd76 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp @@ -37,6 +37,7 @@ namespace nd4j { output->assign(input); const bool lock = block.getBArguments()->empty() ? true : B_ARG(0); + const bool checkIndices = block.getBArguments()->size() <= 1 ? false : B_ARG(1); const int inRank = input->rankOf(); const int indRank = indices->rankOf(); @@ -68,9 +69,16 @@ namespace nd4j { REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str()); } - if (!indices->isEmpty()) + if (!indices->isEmpty()) { + + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *output, 0); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "SCATTER_UPD OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + // ScatterHelper::template scatterApply>(output, indices, updates); helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index 2e38f9977..61ed3bc65 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -22,7 +22,8 @@ #if NOT_EXCLUDED(OP_gather) #include -#include +#include +#include namespace nd4j { @@ -34,8 +35,10 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { auto input = INPUT_VARIABLE(0); auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr; - auto output = OUTPUT_VARIABLE(0); - + auto output = OUTPUT_VARIABLE(0); + + const bool checkIndices = block.getBArguments()->empty() ? 
false : B_ARG(0); + //Edge case: empty indices -> empty output if(indices != nullptr && indices->isEmpty()){ REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty"); @@ -47,7 +50,7 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { std::vector intArgs; if (block.width() > 2) { intArgs = INPUT_VARIABLE(2)->template asVectorT(); - } + } else { if (numOfIntArgs == 0) intArgs.emplace_back(0); @@ -64,14 +67,16 @@ CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) { REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank); REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !"); - if (indices != nullptr) { - for(int i = 0; i < indices->lengthOf(); ++i) - REQUIRE_TRUE(indices->e(i) < input->sizeAt(intArgs[0]), 0, "GATHER op: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !"); - } - else { - for(int i = 1; i < numOfIntArgs; ++i) - REQUIRE_TRUE(intArgs[i] < input->sizeAt(intArgs[0]), 0, "GATHER op: some of indexes is larger than corresponding shape of input array !"); - } + if(checkIndices) { + + NDArray* pIndices = indices; + if(indices == nullptr) + pIndices = new NDArray(input->ordering(), {static_cast(intArgs.size()) - 1}, std::vector(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext()); + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *pIndices, *input, intArgs[0]); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + if(indices == nullptr) + delete pIndices; + } helpers::gather(block.launchContext(), input, indices, output, intArgs); @@ -87,7 +92,7 @@ DECLARE_TYPES(gather) { DECLARE_SHAPE_FN(gather) { - // check shape of paddings + // check shape of paddings auto inputShapeInfo = inputShape->at(0); Nd4jLong* outputShapeInfo = nullptr; @@ -105,21 +110,21 @@ DECLARE_SHAPE_FN(gather) { REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank); bool isEmpty = false; - + if (block.width() > 1) { auto indicesShapeInfo = inputShape->at(1); - + int indicesRank = shape::rank(indicesShapeInfo); - + int outputRank = inputRank + indicesRank - 1; - + ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); // fill output shapeInfo outputShapeInfo[0] = outputRank; - int shapeIdx = 1; - - for(int i = 0; i < axis; ++i) + int shapeIdx = 1; + + for(int i = 0; i < axis; ++i) outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; for(int i = 0; i < indicesRank; ++i) @@ -127,7 +132,7 @@ DECLARE_SHAPE_FN(gather) { for(int i = axis+1; i < inputRank; ++i) outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; - } + } else if (block.numI() > 1) { int indicesRank = block.numI() == 2 ? 
0 : 1; @@ -137,7 +142,7 @@ DECLARE_SHAPE_FN(gather) { // building shape manually outputShapeInfo[0] = outputRank; - int shapeIdx = 1; + int shapeIdx = 1; for(int i = 0; i < axis; ++i) outputShapeInfo[shapeIdx++] = inputShapeInfo[i+1]; diff --git a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp index c64150c04..c889569e2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gatherNd.cpp @@ -23,6 +23,7 @@ #include #include +#include namespace nd4j { namespace ops { @@ -35,6 +36,8 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) { auto indices = INPUT_VARIABLE(1); auto output = OUTPUT_VARIABLE(0); + const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0); + const int rankIn = input->rankOf(); const int rankInd = indices->rankOf(); @@ -42,6 +45,11 @@ CUSTOM_OP_IMPL(gather_nd, 2, 1, false, 0, 0) { int lastIndDim = indices->sizeAt(-1); REQUIRE_TRUE(lastIndDim <= rankIn, 0, "GATHER_ND op: the last dimension of indices array must be <= rank of input array but got %i and %i correspondingly!", lastIndDim, rankIn); + if(checkIndices) { + const Nd4jLong numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input); + REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER_ND OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx); + } + helpers::gatherND(block.launchContext(), *input, *indices, *output); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 99605e7cc..9ae191c76 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -27,6 +27,49 @@ namespace nd4j { namespace ops { namespace helpers { +/////////////////////////////////////////////////////////////////// +// x - indices, z - input/output +template +Nd4jLong checkIndices_(const NDArray& indices, const NDArray& output, const int axis) { + + std::atomic numOfBadIndx{0}; + + const auto x = indices.bufferAsT(); + + const auto xShapeInfo = indices.getShapeInfo(); + const auto zShapeInfo = output.getShapeInfo(); + + const auto xRank = indices.rankOf(); + + auto func = PRAGMA_THREADS_FOR { + + Nd4jLong xCoords[MAX_RANK]; + + for (auto i = start; i < stop; i += increment) { + + shape::index2coords(i, xShapeInfo, xCoords); + + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; + + if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? 
xCoords[xRank-1] : axis)) { + printf("checkIndices: out of range element %lld at index %ld \n", currentInd, i); + ++numOfBadIndx; + } + } + }; + + samediff::Threads::parallel_for(func, 0, indices.lengthOf()); + + return numOfBadIndx; +} + +/////////////////////////////////////////////////////////////////// +Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { + + BUILD_SINGLE_SELECTOR(indices.dataType(), return checkIndices_, (indices, output, axis), INDEXING_TYPES); +} + +/////////////////////////////////////////////////////////////////// void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { const int outRank = output.rankOf(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu index 308e58814..f6d8acc77 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather.cu @@ -108,12 +108,12 @@ __host__ static void gatherCudaLauncher(const cudaStream_t *stream, const int nu void gather(nd4j::LaunchContext * context, const NDArray* input, const NDArray* indices, NDArray* output, const std::vector& intArgs) { const int inputRank = input->rankOf(); - int axis = intArgs.size() > 0 ? intArgs[0] : 0; + const int numOfIntArgs = intArgs.size(); + + int axis = numOfIntArgs > 0 ? intArgs[0] : 0; if(axis < 0) axis += inputRank; - const int numOfIntArgs = intArgs.size(); - if (indices == nullptr && numOfIntArgs == 2) { // scalar case output->assign((*input)(intArgs[1], {axis})); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu index 11ba6571b..6b3bf5135 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gather_nd.cu @@ -106,7 +106,7 @@ namespace nd4j { const auto xOffset = shape::getOffset(xShapeInfo, xCoordStart); z[zOffset] = x[xOffset]; - printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); + // printf("z[%lld] = x[%lld] = %f\n", zOffset, xOffset, (float) z[zOffset]); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu index 501b9bca4..ec85efddf 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/scatter.cu @@ -31,142 +31,688 @@ namespace nd4j { namespace ops { namespace helpers { - // template - // __global__ static void scatterCuda(const int opCode, const int numOfSubArrs, - // void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, - // void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, - // const int* indexes, unsigned int arrLenX, unsigned int arrLenY) { - // __shared__ T *x, *y; +/////////////////////////////////////////////////////////////////// +// x - indices, y - contains number of bad indices, z - input/output +template +__global__ static void checkIndicesCuda(const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { - // if (locking) { + const auto x = reinterpret_cast(vx); - // for (int e = 0; e < numOfSubArrs; e++) { + __shared__ int xRank, *coords, xLastDim; + __shared__ Nd4jLong xLen, numOfBadIndxPerBlock; - // const auto xIndex = indexes[e]; - // const bool isOwner = xIndex < gridDim.x ? 
blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); - // if (!isOwner) - // continue; + xRank = shape::rank(xShapeInfo); + xLen = shape::length(xShapeInfo); - // if (threadIdx.x == 0) { - // x = reinterpret_cast(vx) + xOffsets[xIndex]; - // y = reinterpret_cast(vy) + yOffsets[e]; - // } - // __syncthreads(); + numOfBadIndxPerBlock = 0; + } + __syncthreads(); - // for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + auto xCoords = coords + threadIdx.x * xRank; - // const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - // const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < xLen; i += gridDim.x * blockDim.x) { - // switch (opCode) { - // case pairwise::Add: - // x[xOffset] += y[yOffset]; - // break; - // case pairwise::Subtract: - // x[xOffset] -= y[yOffset]; - // break; - // case pairwise::Multiply: - // x[xOffset] *= y[yOffset]; - // break; - // case pairwise::Divide: - // x[xOffset] /= y[yOffset]; - // break; - // case pairwise::ReverseSubtract: - // x[xOffset] = y[yOffset] - x[xOffset]; - // break; - // case pairwise::ReverseDivide: - // x[xOffset] = y[yOffset] / x[xOffset]; - // break; - // case pairwise::CopyPws: - // x[xOffset] = y[yOffset]; - // break; - // default: - // continue; - // } - // } - // __syncthreads(); - // } - // } else { - // for (int e = blockIdx.x; e < numOfSubArrs; e+= gridDim.x) { + shape::index2coords(i, xShapeInfo, xCoords); - // if (threadIdx.x == 0) { - // const auto xIndex = indexes[e]; - // x = reinterpret_cast(vx) + xOffsets[xIndex]; - // y = reinterpret_cast(vy) + yOffsets[e]; - // } - // __syncthreads(); + const Nd4jLong currentInd = x[shape::getOffset(xShapeInfo, xCoords)]; - // for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - // const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - // const auto yOffset = shape::getIndexOffset(i, yShapeInfo); + if(currentInd >= shape::sizeAt(zShapeInfo, axis == -1 ? 
xCoords[xRank-1] : axis)) { + printf("checkIndices cuda: out of range element %lld at index %lld \n", currentInd, i); + nd4j::math::atomics::nd4j_atomicAdd(&numOfBadIndxPerBlock, 1); + } + } + __syncthreads(); - // switch (opCode) { - // case pairwise::Add: - // x[xOffset] += y[yOffset]; - // break; - // case pairwise::Subtract: - // x[xOffset] -= y[yOffset]; - // break; - // case pairwise::Multiply: - // x[xOffset] *= y[yOffset]; - // break; - // case pairwise::Divide: - // x[xOffset] /= y[yOffset]; - // break; - // case pairwise::ReverseSubtract: - // x[xOffset] = y[yOffset] - x[xOffset]; - // break; - // case pairwise::ReverseDivide: - // x[xOffset] = y[yOffset] / x[xOffset]; - // break; - // case pairwise::CopyPws: - // x[xOffset] = y[yOffset]; - // break; - // default: - // continue; - // } - // } - // __syncthreads(); - // } - // } - // } + if (threadIdx.x == 0 && numOfBadIndxPerBlock != 0) + nd4j::math::atomics::nd4j_atomicAdd(y, numOfBadIndxPerBlock); +} + +/////////////////////////////////////////////////////////////////// +template +static void checkIndicesCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const void *vx, const Nd4jLong *xShapeInfo, Nd4jLong* y, const Nd4jLong *zShapeInfo, const int axis) { + + checkIndicesCuda<<>>(vx, xShapeInfo, y, zShapeInfo, axis); +} - // template - // void scatter_(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - // std::vector dims = {0}; - // auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), dims); +/////////////////////////////////////////////////////////////////// +Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis) { - // auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), inverted); - // auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), inverted); + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * indices.rankOf() + 256; - // auto psX = packX.specialShapeInfo(); - // auto psY = packY.specialShapeInfo(); + const auto xType = indices.dataType(); - // PointersManager manager(context, "scatter"); + PointersManager manager(context, "scatterNDcheckIndices"); - // auto poX = packX.specialOffsets(); - // auto poY = packY.specialOffsets(); + // scalar, initial value = 0 + NDArray numOfBadIndx(nd4j::DataType::INT64, context, true); - // NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + NDArray::prepareSpecialUse({&numOfBadIndx}, {&indices}); + BUILD_SINGLE_SELECTOR(xType, checkIndicesCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), reinterpret_cast(numOfBadIndx.getSpecialBuffer()), output.getSpecialShapeInfo(), axis), INDEXING_TYPES); + NDArray::registerSpecialUse({&numOfBadIndx}, {&indices}); - // unsigned int tadLengthX = shape::length(packX.primaryShapeInfo()); - // unsigned int tadLengthY = shape::length(packY.primaryShapeInfo()); - // if (tadLengthX != tadLengthY) - // throw std::runtime_error("scatter: Lengths of TADs must be equal"); + manager.synchronize(); - // auto blockSize = nd4j::math::nd4j_max(32, nd4j::math::nd4j_min(tadLengthX, 1024)); + return numOfBadIndx.t(0); +} - // if 
(lock)
-    //             scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.getSpecialBuffer(), psX, poX, updates.getSpecialBuffer(), psY, poY, reinterpret_cast(indices.getSpecialBuffer()), tadLengthX, tadLengthY);
-    //         else
-    //             scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.getSpecialBuffer(), psX, poX, updates.getSpecialBuffer(), psY, poY, reinterpret_cast(indices.getSpecialBuffer()), tadLengthX, tadLengthY);

+///////////////////////////////////////////////////////////////////
+// x - indices, y - updates, z - input/output
+template<typename X, typename Y>
+__global__ static void scatterLockCuda(const int opCode,
+                                       const void *vx, const Nd4jLong *xShapeInfo,
+                                       const void *vy, const Nd4jLong *yShapeInfo,
+                                             void *vz, const Nd4jLong *zShapeInfo) {
+
+    const auto x = reinterpret_cast<const X*>(vx);
+    const auto y = reinterpret_cast<const Y*>(vy);
+          auto z = reinterpret_cast<Y*>(vz);
+
+    __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords;
+    __shared__ Nd4jLong xLen, zLen;
+    __shared__ bool is1Dcase, xySameStride;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char shmem[];
+        coords = reinterpret_cast<int*>(shmem);
+
+        xLen = shape::length(xShapeInfo);
+        zLen = shape::length(zShapeInfo);
+
+        xRank = shape::rank(xShapeInfo);
+        yRank = shape::rank(yShapeInfo);
+        zRank = shape::rank(zShapeInfo);
+
+        xNonUnitDim = yNonUnitDim = zNonUnitDim = 0;
+
+        is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo));
+
+        if(is1Dcase)
+            xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] == shape::stride(yShapeInfo)[yNonUnitDim];
+    }
+    __syncthreads();
+
+    Nd4jLong yOffset, zOffset;
+    int zFirstCoord, *yCoords, *zCoords;
+
+    for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) {
+
+        if(!is1Dcase) {
+            yCoords = coords + threadIdx.x * (yRank + zRank);
+            zCoords = yCoords + yRank;
+            shape::index2coords(i, zShapeInfo, zCoords);
+        }
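+        // lock mode: each thread owns a single output element i and scans every index j below,
+        // applying only the updates that target it; conflicting writes are thereby serialized without atomics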
+        for (Nd4jLong j = 0; j < xLen; ++j) {
+
+            if(is1Dcase) {
+
+                yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim];
+                zFirstCoord = x[xySameStride ? yOffset : j * shape::stride(xShapeInfo)[xNonUnitDim]];
+
+                if(i != zFirstCoord)
+                    continue;
+
+                zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim];
+            }
+            else {
+
+                shape::index2coords(j, xShapeInfo, yCoords);    // first xRank coordinates in yCoords are the same for y and x
+
+                zFirstCoord = x[shape::getOffset(xShapeInfo, yCoords)];
+
+                if(zCoords[0] != zFirstCoord)
+                    continue;
+
+                for (uint k = 0; k < yRank - xRank; ++k)
+                    yCoords[xRank + k] = zCoords[k + 1];
+
+                yOffset = shape::getOffset(yShapeInfo, yCoords);
+                zOffset = shape::getOffset(zShapeInfo, zCoords);
+            }
+
+            switch (opCode) {
+                case pairwise::Add:
+                    z[zOffset] += y[yOffset];
+                    break;
+                case pairwise::Subtract:
+                    z[zOffset] -= y[yOffset];
+                    break;
+                case pairwise::Multiply:
+                    z[zOffset] *= y[yOffset];
+                    break;
+                case pairwise::Divide:
+                    z[zOffset] /= y[yOffset];
+                    break;
+                case pairwise::ReverseSubtract:
+                    z[zOffset] = y[yOffset] - z[zOffset];
+                    break;
+                case pairwise::ReverseDivide:
+                    z[zOffset] = y[yOffset] / z[zOffset];
+                    break;
+                case pairwise::CopyPws:
+                    z[zOffset] = y[yOffset];
+                    break;
+                case pairwise::MaxPairwise:
+                    if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset];
+                    break;
+                case pairwise::MinPairwise:
+                    if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset];
+                    break;
+                default:
+                    continue;
+            }
+        }
+    }
+}
+
+
+///////////////////////////////////////////////////////////////////
+// x - indices, y - updates, z - input/output
+template<typename X, typename Y>
+__global__ static void scatterCuda(const int opCode,
+                                   const void *vx, const Nd4jLong *xShapeInfo,
+                                   const void *vy, const Nd4jLong *yShapeInfo,
+                                         void *vz, const Nd4jLong *zShapeInfo) {
+
+    const auto x = reinterpret_cast<const X*>(vx);
+    const auto y = reinterpret_cast<const Y*>(vy);
+          auto z = reinterpret_cast<Y*>(vz);
+
+    __shared__ int xRank, yRank, zRank, xNonUnitDim, yNonUnitDim, zNonUnitDim, *coords;
+    __shared__ Nd4jLong yLen;
+    __shared__ bool is1Dcase, xySameStride;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char shmem[];
+        coords = reinterpret_cast<int*>(shmem);
+
+        yLen = shape::length(yShapeInfo);
+
+        xRank = shape::rank(xShapeInfo);
+        yRank = shape::rank(yShapeInfo);
+        zRank = shape::rank(zShapeInfo);
+
+        xNonUnitDim = yNonUnitDim = zNonUnitDim = 0;
+
+        is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo));
+
+        if(is1Dcase)
+            xySameStride = shape::stride(xShapeInfo)[xNonUnitDim] == shape::stride(yShapeInfo)[yNonUnitDim];
+    }
+    __syncthreads();
+
+    Nd4jLong xOffset, yOffset, zOffset;
+    int *yCoords, *zCoords;
+
+    if(!is1Dcase) {
+        yCoords = coords + threadIdx.x * (yRank + zRank);
+        zCoords = yCoords + yRank;
+    }
+
+    for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) {
+
+        if(is1Dcase) {
+
+            yOffset = i * shape::stride(yShapeInfo)[yNonUnitDim];
+            zOffset = x[xySameStride ? 
yOffset : i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; + } + else { + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + xOffset = shape::getOffset(xShapeInfo, yCoords); // first xRank coordinates in yCoords are the same for y and x -> for (uint j = 0; j < xRank; ++j) xCoords[j] = yCoords[j]; + + zCoords[0] = x[xOffset]; + + for (uint j = 0; j < yRank - xRank; ++j) + zCoords[j + 1] = yCoords[xRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } + } +} + +/////////////////////////////////////////////////////////////////// +template +static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const int opCode, + const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + const bool lock) { + + if(lock) + scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +} + + +/////////////////////////////////////////////////////////////////// +void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = ((lock ? 
output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = sizeof(int) * threadsPerBlock * (updates.rankOf() + output.rankOf()) + 256; + + PointersManager manager(context, "scatter"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +// x - indices, y - updates, z - output +template +__global__ static void scatterNDLockCuda(const int opCode, + const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong zLen, len; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); + + len = is1Dcase ? 
shape::length(xShapeInfo) : shape::length(xShapeInfo) / xLastDim; + zLen = shape::length(zShapeInfo); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset, xOffset; + int *yCoords, *zCoords; + + if(!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen; i += gridDim.x * blockDim.x) { + + if(!is1Dcase) + shape::index2coords(i, zShapeInfo, zCoords); + + for (Nd4jLong j = 0; j < len; ++j) { // if !is1Dcase then we loop through first xRank-1 dimensions of x, that is we exclude last x dimension + + if(is1Dcase) { + + if(x[j * shape::stride(xShapeInfo)[xNonUnitDim]] != i) + continue; + + yOffset = j * shape::stride(yShapeInfo)[yNonUnitDim]; + zOffset = i * shape::stride(zShapeInfo)[zNonUnitDim]; + } + else { + + shape::index2coords(j, xRank-1, shape::shapeOf(const_cast(xShapeInfo)), yCoords); // first xRank-1 coordinates in yCoords are the same for y and x + + // first iteration + yCoords[xRank - 1] = 0; + xOffset = shape::getOffset(xShapeInfo, yCoords); + if(zCoords[0] != x[xOffset]) + continue; + + // rest iterations + bool matched = true; + for (uint k = 1; k < xLastDim; ++k) { + yCoords[xRank - 1] = k; + xOffset += shape::stride(xShapeInfo)[xRank-1]; + if(zCoords[k] != x[xOffset]) { + matched = false; + break; + } + } + + if(!matched) + continue; + + for (uint k = xLastDim; k < zRank; ++k) + yCoords[yRank - zRank + k] = zCoords[k]; + + yOffset = shape::getOffset(yShapeInfo, yCoords); + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } + } + } +} + +/////////////////////////////////////////////////////////////////// +// x - indices, y - updates, z - output +template +__global__ static void scatterNDCuda(const int opCode, + const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const auto x = reinterpret_cast(vx); + const auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ int xRank, yRank, zRank, biggerXYRank, xLastDim, *coords, xNonUnitDim, yNonUnitDim, zNonUnitDim; + __shared__ Nd4jLong yLen; + __shared__ bool is1Dcase; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + coords = reinterpret_cast(shmem); + + yLen = shape::length(yShapeInfo); + xRank = shape::rank(xShapeInfo); + yRank = shape::rank(yShapeInfo); + zRank = shape::rank(zShapeInfo); + xLastDim = shape::sizeAt(xShapeInfo, -1); + + biggerXYRank = xRank > yRank ? 
xRank : yRank; + + xNonUnitDim = yNonUnitDim = zNonUnitDim = 0; + + is1Dcase = (shape::isCommonVector(zShapeInfo, zNonUnitDim) || shape::isScalar(zShapeInfo)) && (shape::isCommonVector(yShapeInfo, yNonUnitDim) || shape::isScalar(yShapeInfo)) && (shape::isCommonVector(xShapeInfo, xNonUnitDim) || shape::isScalar(xShapeInfo)); + } + __syncthreads(); + + Nd4jLong yOffset, zOffset; + int *yCoords, *zCoords; + + if(!is1Dcase) { + yCoords = coords + threadIdx.x * (biggerXYRank + zRank); + zCoords = yCoords + biggerXYRank; + } + + for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < yLen; i += gridDim.x * blockDim.x) { + + if(is1Dcase) { + + yOffset = i * shape::stride(yShapeInfo)[zNonUnitDim]; + zOffset = x[i * shape::stride(xShapeInfo)[xNonUnitDim]] * shape::stride(zShapeInfo)[zNonUnitDim]; + } + else { + + shape::index2coords(i, yShapeInfo, yCoords); + + yOffset = shape::getOffset(yShapeInfo, yCoords); + + if(yRank >= xRank) + zCoords[xLastDim] = yCoords[xRank - 1]; // saving y coordinate, since it might be changed in next instructions + + for (uint j = 0; j < xLastDim; ++j) { // first xRank-1 coordinates in yCoords are the same for y and x + yCoords[xRank - 1] = j; + zCoords[j] = x[shape::getOffset(xShapeInfo, yCoords)]; + } + + for (uint j = xLastDim + 1; j < zRank; ++j) + zCoords[j] = yCoords[yRank - zRank + j]; + + zOffset = shape::getOffset(zShapeInfo, zCoords); + } + + switch (opCode) { + case pairwise::Add: + z[zOffset] += y[yOffset]; + break; + case pairwise::Subtract: + z[zOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + z[zOffset] *= y[yOffset]; + break; + case pairwise::Divide: + z[zOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + z[zOffset] = y[yOffset] - z[zOffset]; + break; + case pairwise::ReverseDivide: + z[zOffset] = y[yOffset] / z[zOffset]; + break; + case pairwise::CopyPws: + z[zOffset] = y[yOffset]; + break; + case pairwise::MaxPairwise: + if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; + break; + case pairwise::MinPairwise: + if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; + break; + default: + continue; + } + } +} + +/////////////////////////////////////////////////////////////////// +template +static void scatterNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const int opCode, + const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + const bool lock) { + + if(lock) + scatterNDLockCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); + else + scatterNDCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +} + +/////////////////////////////////////////////////////////////////// +void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + + const int xRank = indices.rankOf(); + const int yRank = updates.rankOf(); + const int zRank = output.rankOf(); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = ((lock ? output.lengthOf() : updates.lengthOf()) + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(int) * ((yRank > xRank ? 
yRank : xRank) + zRank) + 256; + + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + PointersManager manager(context, "scatterND"); + + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), lock), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + + manager.synchronize(); +} + +/////////////////////////////////////////////////////////////////// +template +__global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong xLen, *sharedMem; + __shared__ int xRank; // xRank = zRank, yRank = xRank + 1 + + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + + xLen = shape::length(xShapeInfo); + xRank = shape::rank(xShapeInfo); + } + __syncthreads(); + + const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; + + if(xInd >= xLen) + return; + + auto coords = sharedMem + threadIdx.x * (xRank + 1); + + shape::index2coords(xInd, xShapeInfo, coords); + + // y last coordinate + coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; + + const auto yOffset = shape::getOffset(yShapeInfo, coords); + + if(z == nullptr) { // gradient calculation + y[yOffset] -= 1.f; + } + else { + z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; + } +} + +/////////////////////////////////////////////////////////////////// +template +static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) { + + scatterForLossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); +} + +/////////////////////////////////////////////////////////////////// +void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { + // shapes of indices and output must be the same + // shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b} + + PointersManager manager(context, "scatterForLoss"); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = updates.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; + + if(calcGrad) { + NDArray::prepareSpecialUse({&updates}, {&indices}); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&updates}, {&indices}); + } + else { + NDArray::prepareSpecialUse({&output}, {&indices, &updates}); + BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), 
scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); + NDArray::registerSpecialUse({&output}, {&indices, &updates}); + } + + manager.synchronize(); +} + +} +} +} + + +/* + +/////////////////////////////////////////////////////////////////// +template +static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + const int opCode, + const void* vx, const Nd4jLong *xShapeInfo, + const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, + void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, + const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { + + scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); +} - // NDArray::registerSpecialUse({&output}, {&updates, &indices}); - // manager.synchronize(); - // } /////////////////////////////////////////////////////////////////// // x - indices, y - updates, z - input/output @@ -177,6 +723,35 @@ __global__ static void scatterLockCuda(const int opCode, void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong zTadLen) { + + + const int xRank = indices.rankOf(); + + std::vector zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); + + int sizeOfUpdDims = xRank; + if(output.rankOf() == updates.rankOf() && indices.isVector()) + sizeOfUpdDims = 1; + + std::vector yTadDims(sizeOfUpdDims); + std::iota(yTadDims.begin(), yTadDims.end(), 0); + + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); + + const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); + const Nd4jLong yTadLen = shape::length(packY.primaryShapeInfo()); + + const auto threadsPerBlock = nd4j::math::nd4j_max(32, nd4j::math::nd4j_min(zTadLen, 1024)); + const auto blocksPerGrid = indices.lengthOf(); + + const auto xType = indices.dataType(); + const auto yType = updates.dataType(); + + BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); + + + const auto x = reinterpret_cast(vx); const auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); @@ -280,161 +855,143 @@ __global__ static void scatterLockCuda(const int opCode, } } -/////////////////////////////////////////////////////////////////// -template -static void scatterLockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void* vx, const Nd4jLong *xShapeInfo, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong xLen, const Nd4jLong yTadLen, const Nd4jLong 
zTadLen) { + template + __global__ static void scatterCuda(const int opCode, const int numOfSubArrs, + void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, + void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, + const int* indexes, unsigned int arrLenX, unsigned int arrLenY) { - scatterLockCuda<<>>(opCode, vx, xShapeInfo, vy, yTadShapeInfo, yOffsets, vz, zTadShapeInfo, zOffsets, xLen, yTadLen, zTadLen); -} + __shared__ T *x, *y; -/////////////////////////////////////////////////////////////////// -// x - indices, y - updates, z - input/output -template -__global__ static void scatterCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { + if (locking) { - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); + for (int e = 0; e < numOfSubArrs; e++) { - __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong yLen, totalThreads, *coord; + const auto xIndex = indexes[e]; + const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coord = reinterpret_cast(shmem); - yLen = shape::length(yShapeInfo); - totalThreads = gridDim.x * blockDim.x; - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - } - __syncthreads(); + if (!isOwner) + continue; - auto xCoord = coord + threadIdx.x * (xRank + yRank + zRank); - auto yCoord = xCoord + xRank; - auto zCoord = yCoord + yRank; + if (threadIdx.x == 0) { + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + } + __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { - for (Nd4jLong i = tid; i < yLen; i += totalThreads) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - shape::index2coords(i, yShapeInfo, yCoord); + switch (opCode) { + case pairwise::Add: + x[xOffset] += y[yOffset]; + break; + case pairwise::Subtract: + x[xOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + x[xOffset] *= y[yOffset]; + break; + case pairwise::Divide: + x[xOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case pairwise::ReverseDivide: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case pairwise::CopyPws: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } + } + __syncthreads(); + } + } else { + for (int e = blockIdx.x; e < numOfSubArrs; e+= gridDim.x) { - for (uint j = 0; j < xRank; ++j) - xCoord[j] = yCoord[j]; + if (threadIdx.x == 0) { + const auto xIndex = indexes[e]; + x = reinterpret_cast(vx) + xOffsets[xIndex]; + y = reinterpret_cast(vy) + yOffsets[e]; + } + __syncthreads(); - const auto xOffset = shape::getOffset(xShapeInfo, xCoord); - zCoord[0] = x[xOffset]; + for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto yOffset = shape::getIndexOffset(i, yShapeInfo); - for (uint j = 0; j < yRank - xRank; ++j) - zCoord[j + 1] = yCoord[xRank + j]; - - const auto yOffset = shape::getOffset(yShapeInfo, yCoord); - const auto zOffset = shape::getOffset(zShapeInfo, zCoord); - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - 
z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } - } -} - -/////////////////////////////////////////////////////////////////// -template -static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - scatterCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); -} + switch (opCode) { + case pairwise::Add: + x[xOffset] += y[yOffset]; + break; + case pairwise::Subtract: + x[xOffset] -= y[yOffset]; + break; + case pairwise::Multiply: + x[xOffset] *= y[yOffset]; + break; + case pairwise::Divide: + x[xOffset] /= y[yOffset]; + break; + case pairwise::ReverseSubtract: + x[xOffset] = y[yOffset] - x[xOffset]; + break; + case pairwise::ReverseDivide: + x[xOffset] = y[yOffset] / x[xOffset]; + break; + case pairwise::CopyPws: + x[xOffset] = y[yOffset]; + break; + default: + continue; + } + } + __syncthreads(); + } + } + } -/////////////////////////////////////////////////////////////////// -void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + template + void scatter_(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { + std::vector dims = {0}; + auto inverted = ShapeUtils::evalDimsToExclude(output.rankOf(), dims); - PointersManager manager(context, "scatter"); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), inverted); + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), inverted); - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); + auto psX = packX.specialShapeInfo(); + auto psY = packY.specialShapeInfo(); - if(lock) { + PointersManager manager(context, "scatter"); - const int xRank = indices.rankOf(); + auto poX = packX.specialOffsets(); + auto poY = packY.specialOffsets(); - std::vector zTadDims = ShapeUtils::evalDimsToExclude(output.rankOf(), {0}); + NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - int sizeOfUpdDims = xRank; - if(output.rankOf() == updates.rankOf() && indices.isVector()) - sizeOfUpdDims = 1; + unsigned int tadLengthX = shape::length(packX.primaryShapeInfo()); + unsigned int tadLengthY = shape::length(packY.primaryShapeInfo()); + if (tadLengthX != tadLengthY) + throw std::runtime_error("scatter: Lengths of TADs must be equal"); - std::vector yTadDims(sizeOfUpdDims); - std::iota(yTadDims.begin(), yTadDims.end(), 0); + auto blockSize = nd4j::math::nd4j_max(32, nd4j::math::nd4j_min(tadLengthX, 1024)); - auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), ShapeUtils::evalDimsToExclude(updates.rankOf(), yTadDims)); - auto packZ = 
nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); + if (lock) + scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.getSpecialBuffer(), psX, poX, updates.getSpecialBuffer(), psY, poY, reinterpret_cast(indices.getSpecialBuffer()), tadLengthX, tadLengthY); + else + scatterCuda<<<512, blockSize, 1024, *context->getCudaStream()>>>(op, indices.lengthOf(), output.getSpecialBuffer(), psX, poX, updates.getSpecialBuffer(), psY, poY, reinterpret_cast(indices.getSpecialBuffer()), tadLengthX, tadLengthY); - const Nd4jLong zTadLen = shape::length(packZ.primaryShapeInfo()); - const Nd4jLong yTadLen = shape::length(packY.primaryShapeInfo()); + NDArray::registerSpecialUse({&output}, {&updates, &indices}); + manager.synchronize(); + } - const auto threadsPerBlock = nd4j::math::nd4j_max(32, nd4j::math::nd4j_min(zTadLen, 1024)); - const auto blocksPerGrid = indices.lengthOf(); - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterLockCudaLauncher, (blocksPerGrid, threadsPerBlock, 1024, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), indices.lengthOf(), yTadLen, zTadLen), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - } - else { - - const int threadsPerBlock = MAX_NUM_THREADS / 8; - const int blocksPerGrid = (updates.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * (indices.rankOf() + updates.rankOf() + output.rankOf()) + 128; - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - } - - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - manager.synchronize(); -} /////////////////////////////////////////////////////////////////// @@ -447,6 +1004,27 @@ __global__ static void scatterNDLockCuda(const int opCode, const Nd4jLong *zShapeInfo, const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong yTadLen) { + + +--------------------------------------------------------------------------- +const int xLastDim = indices.sizeAt(-1); + + // y_tad and z_tad have the same shape + std::vector yTadDims(zRank - xLastDim), zTadDims(zRank - xLastDim); + for (int j = 0, i = zTadDims.size() - 1; i >=0 ; --i, ++j) { + yTadDims[i] = yRank - 1 - j; + zTadDims[i] = zRank - 1 - j; + } + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(indices.getShapeInfo(), {xRank - 1}); + auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), yTadDims); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = packZ.numberOfTads(); + const int sharedMem = 8 * threadsPerBlock * xLastDim + 128; +--------------------------------------------------------------------------- + // zTadLen == yTadLen if numOfZTads > 1, in opposite case z and y are vectors // numOfXTads == 
numOfYTads if numOfZTads > 1, in opposite case z and y are vectors @@ -565,252 +1143,7 @@ __global__ static void scatterNDLockCuda(const int opCode, } } -/////////////////////////////////////////////////////////////////// -template -static void scatterNDLockCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void* vx, const Nd4jLong *xTadShapeInfo, const Nd4jLong *xOffsets, - const void* vy, const Nd4jLong *yTadShapeInfo, const Nd4jLong *yOffsets, - void* vz, const Nd4jLong *zTadShapeInfo, const Nd4jLong *zOffsets, - const Nd4jLong *zShapeInfo, - const Nd4jLong numOfXTads, const Nd4jLong numOfZTads, const Nd4jLong zTadLen) { - - scatterNDLockCuda<<>>(opCode, - vx, xTadShapeInfo, xOffsets, - vy, yTadShapeInfo, yOffsets, - vz, zTadShapeInfo, zOffsets, - zShapeInfo, - numOfXTads, numOfZTads, zTadLen); -} - -/////////////////////////////////////////////////////////////////// -// x - indices, y - updates, z - output -template -__global__ static void scatterNDCuda(const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank, xLastDim; - __shared__ Nd4jLong yLen, totalThreads, *coord; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - coord = reinterpret_cast(shmem); - yLen = shape::length(yShapeInfo); - totalThreads = gridDim.x * blockDim.x; - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - xLastDim = xShapeInfo[xRank]; - } - __syncthreads(); - - auto xCoord = coord + threadIdx.x * (xRank + yRank + zRank); - auto yCoord = xCoord + xRank; - auto zCoord = yCoord + yRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < yLen; i += totalThreads) { - - shape::index2coords(i, yShapeInfo, yCoord); - - for (uint j = 0; j < xRank - 1; ++j) - xCoord[j] = yCoord[j]; - - for (uint j = 0; j < xLastDim; ++j) { - xCoord[xRank - 1] = j; - const auto xOffset = shape::getOffset(xShapeInfo, xCoord); - zCoord[j] = x[xOffset]; - } - - for (uint j = xLastDim; j < zRank; ++j) - zCoord[j] = yCoord[yRank - zRank + j]; - - const auto yOffset = shape::getOffset(yShapeInfo, yCoord); - const auto zOffset = shape::getOffset(zShapeInfo, zCoord); - - switch (opCode) { - case pairwise::Add: - z[zOffset] += y[yOffset]; - break; - case pairwise::Subtract: - z[zOffset] -= y[yOffset]; - break; - case pairwise::Multiply: - z[zOffset] *= y[yOffset]; - break; - case pairwise::Divide: - z[zOffset] /= y[yOffset]; - break; - case pairwise::ReverseSubtract: - z[zOffset] = y[yOffset] - z[zOffset]; - break; - case pairwise::ReverseDivide: - z[zOffset] = y[yOffset] / z[zOffset]; - break; - case pairwise::CopyPws: - z[zOffset] = y[yOffset]; - break; - case pairwise::MaxPairwise: - if(z[zOffset] < y[yOffset]) z[zOffset] = y[yOffset]; - break; - case pairwise::MinPairwise: - if(z[zOffset] > y[yOffset]) z[zOffset] = y[yOffset]; - break; - default: - continue; - } - } -} - -/////////////////////////////////////////////////////////////////// -template -static void scatterNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const int opCode, - const void *vx, const Nd4jLong *xShapeInfo, - const void *vy, const Nd4jLong *yShapeInfo, - void 
*vz, const Nd4jLong *zShapeInfo) { - - scatterNDCuda<<>>(opCode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); -} - -/////////////////////////////////////////////////////////////////// -void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { - - const int xRank = indices.rankOf(); - const int yRank = updates.rankOf(); - const int zRank = output.rankOf(); - - PointersManager manager(context, "scatterND"); - - NDArray::prepareSpecialUse({&output}, {&updates, &indices}); - - if(lock) { - - const int xLastDim = indices.sizeAt(-1); - - // y_tad and z_tad have the same shape - std::vector yTadDims(zRank - xLastDim), zTadDims(zRank - xLastDim); - for (int j = 0, i = zTadDims.size() - 1; i >=0 ; --i, ++j) { - yTadDims[i] = yRank - 1 - j; - zTadDims[i] = zRank - 1 - j; - } - - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(indices.getShapeInfo(), {xRank - 1}); - auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), yTadDims); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), zTadDims); - - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = packZ.numberOfTads(); - const int sharedMem = 8 * threadsPerBlock * xLastDim + 128; - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDLockCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), packX.specialShapeInfo(), packX.specialOffsets(), updates.getSpecialBuffer(), packY.specialShapeInfo(), packY.specialOffsets(), output.getSpecialBuffer(), packZ.specialShapeInfo(), packZ.specialOffsets(), output.getSpecialShapeInfo(), packX.numberOfTads(), packZ.numberOfTads(), shape::length(packY.primaryShapeInfo())), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - } - else { - - const int threadsPerBlock = MAX_NUM_THREADS / 8; - const int blocksPerGrid = (updates.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = 8 * threadsPerBlock * (xRank + yRank + zRank) + 128; - - const auto xType = indices.dataType(); - const auto yType = updates.dataType(); - - BUILD_DOUBLE_SELECTOR(xType, yType, scatterNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), op, indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), INDEXING_TYPES, GENERIC_NUMERIC_TYPES); - } - - NDArray::registerSpecialUse({&output}, {&updates, &indices}); - manager.synchronize(); -} - -/////////////////////////////////////////////////////////////////// -template -__global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo, - void *vy, const Nd4jLong *yShapeInfo, - void *vz, const Nd4jLong *zShapeInfo) { - - const auto x = reinterpret_cast(vx); - auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Nd4jLong xLen, *sharedMem; - __shared__ int xRank; // xRank = zRank, yRank = xRank + 1 - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - xLen = shape::length(xShapeInfo); - xRank = shape::rank(xShapeInfo); - } - __syncthreads(); - - const auto xInd = threadIdx.x + blockIdx.x * blockDim.x; - - if(xInd >= xLen) - return; - - auto coords = sharedMem + threadIdx.x * (xRank + 1); - - 
shape::index2coords(xInd, xShapeInfo, coords); - - // y last coordinate - coords[xRank] = x[shape::getOffset(xShapeInfo, coords)]; - - const auto yOffset = shape::getOffset(yShapeInfo, coords); - - if(z == nullptr) { // gradient calculation - y[yOffset] -= 1.f; - } - else { - z[shape::getOffset(zShapeInfo, coords)] = y[yOffset]; - } -} - -/////////////////////////////////////////////////////////////////// -template -static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) { - - scatterForLossCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); -} - -/////////////////////////////////////////////////////////////////// -void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) { - // shapes of indices and output must be the same - // shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b} - - PointersManager manager(context, "scatterForLoss"); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = updates.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; - - if(calcGrad) { - NDArray::prepareSpecialUse({&updates}, {&indices}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&updates}, {&indices}); - } - else { - NDArray::prepareSpecialUse({&output}, {&indices, &updates}); - BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INDEXING_TYPES, FLOAT_TYPES); - NDArray::registerSpecialUse({&output}, {&indices, &updates}); - } - - manager.synchronize(); -} - -} -} -} - +*/ // PointersManager manager(&context, "NativeOps::concat"); // PointersManager::printDevContentOnDev(vx, 2); // PointersManager::printDevContentOnDev(xShapeInfo, 8); diff --git a/libnd4j/include/ops/declarable/helpers/scatter.h b/libnd4j/include/ops/declarable/helpers/scatter.h index d0eb76a52..b470285ff 100644 --- a/libnd4j/include/ops/declarable/helpers/scatter.h +++ b/libnd4j/include/ops/declarable/helpers/scatter.h @@ -31,6 +31,8 @@ namespace nd4j { void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad); + + Nd4jLong checkIndices(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& output, const int axis = -1); } } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index d29d1f0e1..c2e39cab5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -76,6 +76,16 @@ TEST_F(DeclarableOpsTests16, scatter_upd_2) { delete result; } +TEST_F(DeclarableOpsTests16, scatter_upd_3) { + + NDArray x('c', {10, 3}, nd4j::DataType::FLOAT32); + NDArray indices('c', {2}, {20,5}, nd4j::DataType::INT32); + NDArray updates('c', {2, 3}, {100,101,102, 200,201,202}, nd4j::DataType::FLOAT32); + NDArray output('c', {10, 3}, nd4j::DataType::FLOAT32); + + nd4j::ops::scatter_upd op; + ASSERT_ANY_THROW(op.execute({&x, &indices, &updates}, {&output}, {}, {}, {true, true})); +} TEST_F(DeclarableOpsTests16, test_size_dtype_1) { auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 4941e7459..076d14385 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -62,7 +62,7 @@ TEST_F(DeclarableOpsTests2, gather_2) { nd4j::ops::gather op; - auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}); + auto result = op.execute({&input}, {}, {1, 0,1, 2,2, 1,2}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -125,7 +125,7 @@ TEST_F(DeclarableOpsTests2, gather_5) { nd4j::ops::gather op; - auto result = op.execute({&input, &indices}, {}, {1}); + auto result = op.execute({&input, &indices}, {}, {1}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -294,7 +294,7 @@ TEST_F(DeclarableOpsTests2, gather_13) { nd4j::ops::gather op; - auto result = op.execute({&input, &indices}, {}, {2}); + auto result = op.execute({&input, &indices}, {}, {2}, {true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -306,6 +306,30 @@ TEST_F(DeclarableOpsTests2, gather_13) { delete result; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_14) { + + NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, nd4j::DataType::INT32); + NDArray output('c', {2,2,3,4}); + + nd4j::ops::gather op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_15) { + + NDArray input ('c', {2,3,4,5}, nd4j::DataType::DOUBLE); + NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, nd4j::DataType::INT32); + NDArray output('c', {2,3, 2,3,4, 5}); + + nd4j::ops::gather op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); +} + TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 2e8d96f3c..2c15f24bc 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -704,7 +704,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test2) { auto expected = NDArrayFactory::create('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4}); nd4j::ops::gather_nd op; - auto results = op.execute({&input, &indices}, {}, {}); + auto results = op.execute({&input, &indices}, {}, {}, {true}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); 
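// The throw-tests introduced above (scatter_upd_3, gather_14, gather_15, and gatherNd_test2
// with the extra {true} boolean argument) exercise the index-validation helper declared in
// scatter.h in this diff: Nd4jLong checkIndices(LaunchContext*, const NDArray& indices,
// const NDArray& output, const int axis = -1). Below is a minimal host-side sketch of the
// semantics being tested, assuming per-element bounds checking against one output dimension;
// countBadIndices is an illustrative name only, not the shipped CUDA implementation.
static Nd4jLong countBadIndices(const nd4j::NDArray& indices, const nd4j::NDArray& output, const int axis) {
    // every element of `indices` must address an existing position along `axis` of `output`
    Nd4jLong numOfBadIndx = 0;
    const Nd4jLong bound = output.sizeAt(axis);
    for (Nd4jLong i = 0; i < indices.lengthOf(); ++i) {
        const Nd4jLong idx = indices.e<Nd4jLong>(i);
        if (idx < 0 || idx >= bound)
            ++numOfBadIndx;     // any non-zero count is treated as a validation failure by the ops
    }
    return numOfBadIndx;
}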
@@ -798,7 +798,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test7) { auto expected = NDArrayFactory::create('c', {3,3}, {3,5,5,8,5,10,2,2,14}); nd4j::ops::gather_nd op; - auto results = op.execute({&input, &indices}, {}, {}); + auto results = op.execute({&input, &indices}, {}, {}, {true}); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); @@ -825,6 +825,52 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) { delete result; } +TEST_F(DeclarableOpsTests5, gatherNd_test9) { + auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); + auto indices = NDArrayFactory::create('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1}); + auto exp = NDArrayFactory::create('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); + x.linspace(1); + + nd4j::ops::gather_nd op; + auto result = op.execute({&x, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + //z->printIndexedBuffer(); + //z->printShapeInfo("z shape"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test10) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto indices = NDArrayFactory::create('c', {2,2,2}, {30,20,1,2, 0,10,0,1}); + + auto output = NDArrayFactory::create('c', {2,2,2}); + + nd4j::ops::gather_nd op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test11) { + + auto input = NDArrayFactory::create('c', {4, 4}); + auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1}); + auto output = NDArrayFactory::create('c', {3,3}); + + nd4j::ops::gather_nd op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 79a569e0f..be25f0c62 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -314,27 +314,6 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { delete result; } -TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) { - auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); - auto indices = NDArrayFactory::create('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1}); - auto exp = NDArrayFactory::create('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); - x.linspace(1); - - nd4j::ops::gather_nd op; - auto result = op.execute({&x, &indices}, {}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - //z->printIndexedBuffer(); - //z->printShapeInfo("z shape"); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - - delete result; -} - TEST_F(DeclarableOpsTests6, Test_Order_1) { auto x = NDArrayFactory::create('f', {2, 3}); auto exp = NDArrayFactory::create('c', {2, 3}); diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index d5880d689..17f00011c 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -756,7 +756,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); 
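    // The scatter_add cases in this hunk now pass both boolean arguments: the first is the
    // existing lock mode, the second enables the new index check. The behaviour they verify is,
    // schematically, a per-index accumulation into the output. A simplified reference follows,
    // assuming a rank-1 index vector and a row-major 2-D output; this is a raw-buffer sketch
    // with illustrative names, not the CUDA kernel used by the library.
    auto scatterAddRef = [](const int* idx, int numIdx, const float* upd, float* out, int cols) {
        for (int i = 0; i < numIdx; ++i)            // i-th row of updates goes to row idx[i] of out
            for (int c = 0; c < cols; ++c)
                out[idx[i] * cols + c] += upd[i * cols + c];   // the pairwise::Add branch
    };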
nd4j::ops::scatter_add op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -791,7 +791,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); nd4j::ops::scatter_add op; - auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true}); + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {true, true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -837,6 +837,18 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_8) { ASSERT_TRUE(expected.equalsTo(z)); } +//////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, Test_Scatter_Add_9) { + auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 10, 0, 0}, nd4j::DataType::INT64); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto output = NDArrayFactory::create('c', {2, 2, 3}); + + nd4j::ops::scatter_add op; + + ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); +} + //////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterMax_test1) { auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); @@ -1010,6 +1022,18 @@ TEST_F(ParityOpsTests, scatterMin_test4) { delete result; } +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterMin_test5) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1,2}, {10,10}, nd4j::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); + + nd4j::ops::scatter_min op; + + ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); +} + //////////////////////////////////////////////////////////////////////// TEST_F(ParityOpsTests, scatterND_test1) { @@ -1019,7 +1043,7 @@ TEST_F(ParityOpsTests, scatterND_test1) { auto exp = NDArrayFactory::create('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::scatter_nd op; - auto result = op.execute({&indices, &updates, &shape}, {}, {}); + auto result = op.execute({&indices, &updates, &shape}, {}, {false, true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -1066,7 +1090,7 @@ TEST_F(ParityOpsTests, scatterND_test3) { updates.linspace(1.f); nd4j::ops::scatter_nd op; - auto result = op.execute({&indices, &updates, &shape}, {}, {}); + auto result = op.execute({&indices, &updates, &shape}, {}, {false, true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -1160,7 +1184,7 @@ TEST_F(ParityOpsTests, scatterND_test7) { updates.linspace(1); nd4j::ops::scatter_nd op; - auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true}); + auto result = op.execute({&indices, &updates, &shape}, {}, {}, {true, true}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); @@ -1193,6 +1217,20 @@ TEST_F(ParityOpsTests, scatterND_test8) { delete result; } +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test9) { + + NDArray indices('c', {2, 3, 
1}, {0., 20., 7., 30., 6., 90.}, nd4j::DataType::INT32);
+    auto updates = NDArrayFactory::create('c', {2,3, 3,4});
+    auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4});
+    auto output = NDArrayFactory::create('c', {10, 3, 4});
+
+    nd4j::ops::scatter_nd op;
+
+    ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true}));
+}
+
+
 ////////////////////////////////////////////////////////////////////////
 TEST_F(ParityOpsTests, scatterND_add_test1) {
@@ -1323,6 +1361,19 @@ TEST_F(ParityOpsTests, scatterND_add_test5) {
     delete result;
 }
 
+////////////////////////////////////////////////////////////////////////
+TEST_F(ParityOpsTests, scatterND_add_test6) {
+
+    auto input = NDArrayFactory::create('c', {6, 4});
+    NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, nd4j::DataType::INT32);
+    auto updates = NDArrayFactory::create('c', {2,3,4});
+    auto output = NDArrayFactory::create('c', {6,4});
+
+    nd4j::ops::scatter_nd_add op;
+
+    ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true}));
+}
+
 ////////////////////////////////////////////////////////////////////////
 TEST_F(ParityOpsTests, scatterND_sub_test1) {
@@ -1586,6 +1637,19 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
     delete result;
 }
 
+////////////////////////////////////////////////////////////////////////
+TEST_F(ParityOpsTests, scatterND_update_test6) {
+
+    auto input = NDArrayFactory::create('c', {6, 4});
+    NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, nd4j::DataType::INT32);
+    auto updates = NDArrayFactory::create('c', {3,3});
+    auto output = NDArrayFactory::create('c', {6,4});
+
+    nd4j::ops::scatter_nd_update op;
+
+    ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true}));
+}
+
 //////////////////////////////////////////////////////////////////////
 TEST_F(ParityOpsTests, scatter_update_1) {
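// The scatter_nd-style throw-tests above use index rows containing values such as 50, 40, or 30
// that cannot address a {6, 4} (or {10, 3, 4}) output, which is exactly what the second boolean
// argument is meant to reject. The last dimension of `indices` supplies the leading output
// coordinates for each update element, so a row is valid only if every entry fits the matching
// output dimension. A minimal sketch of that per-row check follows; the names are illustrative
// and this is not the library's API or the CUDA path shown earlier in this diff.
static bool indexRowInRange(const Nd4jLong* idxRow, const Nd4jLong* outShape, const int xLastDim) {
    // idxRow has xLastDim entries; entry j addresses dimension j of the output
    for (int j = 0; j < xLastDim; ++j)
        if (idxRow[j] < 0 || idxRow[j] >= outShape[j])
            return false;
    return true;
}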