diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index eced3c2b4..b03d19451 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -45,18 +45,18 @@ namespace sd { DECLARE_TYPES(max_pool_with_argmax) { getOpDescriptor() ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, {ALL_INTS}); + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INDICES}); } DECLARE_SHAPE_FN(max_pool_with_argmax) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; + auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); + auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype)); - auto in = inputShape->at(0); - auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); - auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64)); - - return SHAPELIST(valuesShape, indicesShape); + return SHAPELIST(valuesShape, indicesShape); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index a458b5eff..ebb9d53fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -73,7 +73,7 @@ namespace helpers { } void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index 6e70d4510..8c30e510f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -88,7 +88,7 @@ namespace helpers { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { NDArray::prepareSpecialUse({values, indices}, {input}); auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({values, indices}, {input}); }