Shugeo segment fix2 (#185)

* Added test for segment_mean.

* Added another test for segment_mean.

* Fixed segment_* ops helpers for cuda to proper use external data.
master
shugeo 2019-08-27 18:25:39 +03:00 committed by raver119
parent 9d325ad070
commit 0849b3c1a4
7 changed files with 85 additions and 1 deletions

View File

@ -201,7 +201,9 @@ namespace nd4j {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -240,7 +242,9 @@ namespace nd4j {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -370,8 +374,10 @@ namespace nd4j {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -416,7 +422,9 @@ namespace nd4j {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
} }
} }

View File

@ -163,7 +163,7 @@ namespace helpers {
classesRangesBegs.assign(indices->lengthOf()); classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0); classesRangesLens.assign(0);
NDArray::prepareSpecialUse({output}, {input, indices});
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
@ -182,11 +182,14 @@ namespace helpers {
Nd4jLong* outputTadOffsets = packZ.specialOffsets(); Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
} }
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -194,6 +197,8 @@ namespace helpers {
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); // NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray::prepareSpecialUse({output}, {input, indices});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); // NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
@ -221,12 +226,15 @@ namespace helpers {
dims.x = input->sizeAt(0); dims.x = input->sizeAt(0);
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
} }
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES); NUMERIC_TYPES, INDEXING_TYPES);
NDArray::prepareSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -349,8 +357,10 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
// segmen mean bp main // segmen mean bp main
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -399,7 +409,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
} }

View File

@ -192,7 +192,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -233,8 +235,10 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES); NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
template <typename T, typename I> template <typename T, typename I>
@ -364,8 +368,10 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
// segmen min // segmen min
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
template <typename T, typename I> template <typename T, typename I>
@ -409,7 +415,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
} }
} }

View File

@ -192,7 +192,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -231,8 +233,10 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES); NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -358,8 +362,10 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -404,7 +410,9 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //

View File

@ -146,8 +146,10 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output),
FLOAT_TYPES, INDEXING_TYPES); FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I> template <typename T, typename I>
@ -270,7 +272,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
} }
} }

View File

@ -190,7 +190,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
@ -229,8 +231,10 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices});
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INDEXING_TYPES); NUMERIC_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices});
} }
@ -343,8 +347,10 @@ namespace helpers {
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES); indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
template <typename T, typename I> template <typename T, typename I>
@ -381,7 +387,9 @@ namespace helpers {
} }
// -------------------------------------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
} }
} }

View File

@ -1423,6 +1423,42 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_2) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests7, TestSegmentMean_02) {
auto x = NDArrayFactory::create<double>('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
nd4j::ops::segment_mean op;
auto result = op.execute({&x, &idx}, {}, {});
ASSERT_EQ(result->status(), Status::OK());
ASSERT_EQ(result->size(), 1);
exp.printIndexedBuffer("Expect Mean");
result->at(0)->printIndexedBuffer("Output Mean");
// exp.printShapeInfo("Exp Shape");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
TEST_F(DeclarableOpsTests7, TestSegmentMean_021) {
auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
nd4j::ops::segment_mean op;
x.linspace(1.);
auto result = op.execute({&x, &idx}, {}, {});
ASSERT_EQ(result->status(), Status::OK());
ASSERT_EQ(result->size(), 1);
exp.printIndexedBuffer("Expect Mean");
result->at(0)->printIndexedBuffer("Output Mean");
// exp.printShapeInfo("Exp Shape");
ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) {
auto x = NDArrayFactory::create<double>('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto x = NDArrayFactory::create<double>('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.});