max_pool with argmax - more data types (#486)
Signed-off-by: raver119@gmail.com <raver119@gmail.com>master
parent
fadc2d8622
commit
8733c0c3ed
|
@ -45,18 +45,18 @@ namespace sd {
|
|||
DECLARE_TYPES(max_pool_with_argmax) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(sd::DataType::ANY)
|
||||
->setAllowedOutputTypes(0, DataType::INHERIT)
|
||||
->setAllowedOutputTypes(1, {ALL_INTS});
|
||||
->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
|
||||
->setAllowedOutputTypes(1, {ALL_INDICES});
|
||||
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(max_pool_with_argmax) {
|
||||
auto in = inputShape->at(0);
|
||||
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
|
||||
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
|
||||
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
|
||||
|
||||
auto in = inputShape->at(0);
|
||||
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
|
||||
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
|
||||
|
||||
return SHAPELIST(valuesShape, indicesShape);
|
||||
return SHAPELIST(valuesShape, indicesShape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ namespace helpers {
|
|||
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
||||
NDArray::prepareSpecialUse({values, indices}, {input});
|
||||
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
|
||||
NDArray::registerSpecialUse({values, indices}, {input});
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue