Shugeo segment fix2 (#185)
* Added test for segment_mean. * Added another test for segment_mean. * Fixed segment_* ops helpers for cuda to proper use external data.master
parent
9d325ad070
commit
0849b3c1a4
|
@ -201,7 +201,9 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
@ -240,7 +242,9 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -370,8 +374,10 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
|
||||||
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -416,7 +422,9 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -163,7 +163,7 @@ namespace helpers {
|
||||||
|
|
||||||
classesRangesBegs.assign(indices->lengthOf());
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
classesRangesLens.assign(0);
|
classesRangesLens.assign(0);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
@ -182,11 +182,14 @@ namespace helpers {
|
||||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
}
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
|
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -194,6 +197,8 @@ namespace helpers {
|
||||||
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
|
|
||||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
@ -221,12 +226,15 @@ namespace helpers {
|
||||||
dims.x = input->sizeAt(0);
|
dims.x = input->sizeAt(0);
|
||||||
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
}
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
|
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
NUMERIC_TYPES, INDEXING_TYPES);
|
NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -349,8 +357,10 @@ namespace helpers {
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
// segmen mean bp main
|
// segmen mean bp main
|
||||||
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
|
||||||
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
@ -399,7 +409,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,7 +192,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -233,8 +235,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
NUMERIC_TYPES, INDEXING_TYPES);
|
NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
|
@ -364,8 +368,10 @@ namespace helpers {
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
// segmen min
|
// segmen min
|
||||||
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input,
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input,
|
||||||
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
|
@ -409,7 +415,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,7 +192,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -231,8 +233,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output),
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
NUMERIC_TYPES, INDEXING_TYPES);
|
NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -358,8 +362,10 @@ namespace helpers {
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input,
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input,
|
||||||
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -404,7 +410,9 @@ namespace helpers {
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
|
@ -146,8 +146,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output),
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
FLOAT_TYPES, INDEXING_TYPES);
|
FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
|
@ -270,7 +272,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -190,7 +190,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
@ -229,8 +231,10 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
NUMERIC_TYPES, INDEXING_TYPES);
|
NUMERIC_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,8 +347,10 @@ namespace helpers {
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input,
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input,
|
||||||
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
indices, gradOut, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
|
@ -381,7 +387,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
// -------------------------------------------------------------------------------------------------------------- //
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1423,6 +1423,42 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests7, TestSegmentMean_02) {
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
|
||||||
|
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
|
||||||
|
|
||||||
|
nd4j::ops::segment_mean op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x, &idx}, {}, {});
|
||||||
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
|
ASSERT_EQ(result->size(), 1);
|
||||||
|
exp.printIndexedBuffer("Expect Mean");
|
||||||
|
result->at(0)->printIndexedBuffer("Output Mean");
|
||||||
|
// exp.printShapeInfo("Exp Shape");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests7, TestSegmentMean_021) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
|
||||||
|
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
|
||||||
|
|
||||||
|
nd4j::ops::segment_mean op;
|
||||||
|
x.linspace(1.);
|
||||||
|
auto result = op.execute({&x, &idx}, {}, {});
|
||||||
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
|
ASSERT_EQ(result->size(), 1);
|
||||||
|
exp.printIndexedBuffer("Expect Mean");
|
||||||
|
result->at(0)->printIndexedBuffer("Output Mean");
|
||||||
|
// exp.printShapeInfo("Exp Shape");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) {
|
TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.});
|
auto x = NDArrayFactory::create<double>('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.});
|
||||||
|
|
Loading…
Reference in New Issue