max_pool with argmax - more data types (#486)

Signed-off-by: raver119@gmail.com <raver119@gmail.com>
master
raver119 2020-06-11 12:39:14 +03:00 committed by GitHub
parent fadc2d8622
commit 8733c0c3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 9 deletions

View File

@ -45,16 +45,16 @@ namespace sd {
DECLARE_TYPES(max_pool_with_argmax) { DECLARE_TYPES(max_pool_with_argmax) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT) ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedOutputTypes(1, {ALL_INTS}); ->setAllowedOutputTypes(1, {ALL_INDICES});
} }
DECLARE_SHAPE_FN(max_pool_with_argmax) { DECLARE_SHAPE_FN(max_pool_with_argmax) {
auto in = inputShape->at(0); auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64)); auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
return SHAPELIST(valuesShape, indicesShape); return SHAPELIST(valuesShape, indicesShape);
} }

View File

@ -73,7 +73,7 @@ namespace helpers {
} }
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
} }
} }

View File

@ -88,7 +88,7 @@ namespace helpers {
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
NDArray::prepareSpecialUse({values, indices}, {input}); NDArray::prepareSpecialUse({values, indices}, {input});
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({values, indices}, {input}); NDArray::registerSpecialUse({values, indices}, {input});
} }