diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu index 180af41e1..dc91a2704 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu @@ -201,7 +201,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -240,7 +242,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -370,8 +374,10 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -416,7 +422,9 @@ namespace nd4j { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu index 3f2168da4..fbb45a375 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu @@ -163,7 +163,7 @@ namespace helpers { classesRangesBegs.assign(indices->lengthOf()); classesRangesLens.assign(0); - + NDArray::prepareSpecialUse({output}, {input, indices}); dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); int* begins = reinterpret_cast(classesRangesBegs.specialBuffer()); int* lengths = reinterpret_cast(classesRangesLens.specialBuffer()); @@ -182,11 +182,14 @@ namespace helpers { Nd4jLong* outputTadOffsets = packZ.specialOffsets(); segmentMeanTadKernel<<sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); } + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -194,6 +197,8 @@ namespace helpers { static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { auto stream = context->getCudaStream(); // NDArray classes = NDArrayFactory::create('c', {numOfClasses, 2}); + NDArray::prepareSpecialUse({output}, {input, indices}); + NDArray classesRangesBegs = NDArrayFactory::create('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create('c', {numOfClasses}); // NDArray row = NDArrayFactory::create('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); @@ -221,12 +226,15 @@ namespace helpers { dims.x = input->sizeAt(0); segmentMeanTadKernel<<>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); } + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::prepareSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -349,8 +357,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // // segmen mean bp main int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -399,7 +409,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu index 0c67b41d5..950abde67 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu @@ -192,7 +192,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -233,8 +235,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } template @@ -364,8 +368,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // // segmen min int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } template @@ -409,7 +415,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu index 78f21916d..3ae4ebcb8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu @@ -192,7 +192,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -231,8 +233,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -358,8 +362,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // @@ -404,7 +410,9 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } // -------------------------------------------------------------------------------------------------------------- // diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu index 4141cefba..229d41cc9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu @@ -146,8 +146,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // template @@ -270,7 +272,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu index 37dacee09..08b36253a 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu @@ -190,7 +190,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } // -------------------------------------------------------------------------------------------------------------- // @@ -229,8 +231,10 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices}); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices}); } @@ -343,8 +347,10 @@ namespace helpers { // -------------------------------------------------------------------------------------------------------------- // int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } template @@ -381,7 +387,9 @@ namespace helpers { } // -------------------------------------------------------------------------------------------------------------- // int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({output}, {input, indices, gradOut}); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 2e1dab1a3..b0488c23a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -1423,6 +1423,42 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { delete result; } +TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { + auto x = NDArrayFactory::create('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + nd4j::ops::segment_mean op; + + auto result = op.execute({&x, &idx}, {}, {}); + ASSERT_EQ(result->status(), Status::OK()); + ASSERT_EQ(result->size(), 1); + exp.printIndexedBuffer("Expect Mean"); + result->at(0)->printIndexedBuffer("Output Mean"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result->at(0))); + + delete result; +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { + auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + nd4j::ops::segment_mean op; + x.linspace(1.); + auto result = op.execute({&x, &idx}, {}, {}); + ASSERT_EQ(result->status(), Status::OK()); + ASSERT_EQ(result->size(), 1); + exp.printIndexedBuffer("Expect Mean"); + result->at(0)->printIndexedBuffer("Output Mean"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result->at(0))); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.});