max_pool with argmax - more data types (#486)

Signed-off-by: raver119@gmail.com <raver119@gmail.com>
master
raver119 2020-06-11 12:39:14 +03:00 committed by GitHub
parent fadc2d8622
commit 8733c0c3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 9 deletions

View File

@ -45,18 +45,18 @@ namespace sd {
DECLARE_TYPES(max_pool_with_argmax) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT)
->setAllowedOutputTypes(1, {ALL_INTS});
->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedOutputTypes(1, {ALL_INDICES});
}
DECLARE_SHAPE_FN(max_pool_with_argmax) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
auto in = inputShape->at(0);
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
return SHAPELIST(valuesShape, indicesShape);
return SHAPELIST(valuesShape, indicesShape);
}
}
}

View File

@ -73,7 +73,7 @@ namespace helpers {
}
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
}
}

View File

@ -88,7 +88,7 @@ namespace helpers {
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
NDArray::prepareSpecialUse({values, indices}, {input});
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({values, indices}, {input});
}