max_pool with argmax - more data types (#486)
Signed-off-by: raver119@gmail.com <raver119@gmail.com>master
parent
fadc2d8622
commit
8733c0c3ed
|
@ -45,18 +45,18 @@ namespace sd {
|
||||||
DECLARE_TYPES(max_pool_with_argmax) {
|
DECLARE_TYPES(max_pool_with_argmax) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes(sd::DataType::ANY)
|
||||||
->setAllowedOutputTypes(0, DataType::INHERIT)
|
->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
|
||||||
->setAllowedOutputTypes(1, {ALL_INTS});
|
->setAllowedOutputTypes(1, {ALL_INDICES});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(max_pool_with_argmax) {
|
DECLARE_SHAPE_FN(max_pool_with_argmax) {
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
|
||||||
|
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
|
||||||
|
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
|
||||||
|
|
||||||
auto in = inputShape->at(0);
|
return SHAPELIST(valuesShape, indicesShape);
|
||||||
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
|
|
||||||
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
|
|
||||||
|
|
||||||
return SHAPELIST(valuesShape, indicesShape);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ namespace helpers {
|
||||||
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
|
||||||
NDArray::prepareSpecialUse({values, indices}, {input});
|
NDArray::prepareSpecialUse({values, indices}, {input});
|
||||||
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
|
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
NDArray::registerSpecialUse({values, indices}, {input});
|
NDArray::registerSpecialUse({values, indices}, {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue