-
Notifications
You must be signed in to change notification settings - Fork 131
Add dynamic shape support for TopK #4880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
02ec17c
b998457
3527d11
171aec9
222a2dd
33398d8
b067618
36ad89d
b55ae07
df26e31
89f31ba
2624cec
bf9e4e4
e4073bd
eedd15f
49b757b
3499f43
d5d0d3e
68332cc
0ee21ec
3431d57
c360580
e6090e3
688864a
13b06cb
a048213
0d08c50
040e186
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,18 +28,20 @@ | |||||||
| #include <migraphx/check_shapes.hpp> | ||||||||
| #include <migraphx/argument.hpp> | ||||||||
| #include <migraphx/config.hpp> | ||||||||
| #include <migraphx/dyn_output.hpp> | ||||||||
| #include <migraphx/op/normalize_attribute.hpp> | ||||||||
| #include <migraphx/par_for.hpp> | ||||||||
| #include <migraphx/ranges.hpp> | ||||||||
| #include <migraphx/value.hpp> | ||||||||
| #include <optional> | ||||||||
|
|
||||||||
| namespace migraphx { | ||||||||
| inline namespace MIGRAPHX_INLINE_NS { | ||||||||
| namespace op { | ||||||||
|
|
||||||||
| struct topk | ||||||||
| { | ||||||||
| int64_t k = 1; | ||||||||
| std::optional<int64_t> k; | ||||||||
| int64_t axis = 0; | ||||||||
| bool largest = true; | ||||||||
|
|
||||||||
|
|
@@ -60,16 +62,38 @@ struct topk | |||||||
|
|
||||||||
| shape normalize_compute_shape(std::vector<shape> inputs) const | ||||||||
| { | ||||||||
| check_shapes{inputs, *this}.has(1, 2); | ||||||||
| auto lens = inputs.at(0).lens(); | ||||||||
| check_shapes{inputs, *this, true}.has(1, 2); | ||||||||
| auto type = inputs.at(0).type(); | ||||||||
|
|
||||||||
| lens[axis] = k; | ||||||||
| if(inputs.at(0).dynamic()) | ||||||||
| { | ||||||||
| auto dyn_dims = inputs.at(0).dyn_dims(); | ||||||||
| if(k.has_value()) | ||||||||
| { | ||||||||
| auto min_lens_vec = inputs.at(0).min_lens(); | ||||||||
| auto max_lens_vec = inputs.at(0).max_lens(); | ||||||||
| auto min_kk = std::min(static_cast<std::size_t>(*k), min_lens_vec[axis]); | ||||||||
| auto max_kk = std::min(static_cast<std::size_t>(*k), max_lens_vec[axis]); | ||||||||
| dyn_dims[axis] = {min_kk, max_kk}; | ||||||||
| } | ||||||||
|
|
||||||||
| shape s_val{type, lens}; | ||||||||
| shape s_ind{shape::int64_type, lens}; | ||||||||
| shape s_val{type, dyn_dims}; | ||||||||
| shape s_ind{shape::int64_type, dyn_dims}; | ||||||||
| return shape({s_val, s_ind}); | ||||||||
| } | ||||||||
| else | ||||||||
| { | ||||||||
| auto lens = inputs.at(0).lens(); | ||||||||
| if(k.has_value()) | ||||||||
| { | ||||||||
| auto kk = std::min(static_cast<std::size_t>(*k), lens[axis]); | ||||||||
| lens[axis] = kk; | ||||||||
| } | ||||||||
|
|
||||||||
| return shape({s_val, s_ind}); | ||||||||
| shape s_val{type, lens}; | ||||||||
| shape s_ind{shape::int64_type, lens}; | ||||||||
| return shape({s_val, s_ind}); | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| template <class Compare> | ||||||||
|
|
@@ -84,13 +108,15 @@ struct topk | |||||||
| }; | ||||||||
| } | ||||||||
|
|
||||||||
| argument compute(const shape& output_shape, std::vector<argument> args) const | ||||||||
| argument compute(const dyn_output& dyn_out, std::vector<argument> args) const | ||||||||
| { | ||||||||
| const auto& output_shape = dyn_out.computed_shape; | ||||||||
| const auto& vec_ss = output_shape.sub_shapes(); | ||||||||
| argument res_val{vec_ss.front()}; | ||||||||
| argument res_ind{vec_ss.back()}; | ||||||||
| auto in_val = args.front(); | ||||||||
| auto relements = in_val.get_shape().lens()[axis]; | ||||||||
| auto actual_k = k.has_value() ? std::min(static_cast<std::size_t>(*k), relements) : relements; | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [format.py] reported by reviewdog 🐶
Suggested change
|
||||||||
| auto make_indices = [&](const auto& m_idx) { | ||||||||
| return [&](int64_t i) { | ||||||||
| if(args.size() < 2) | ||||||||
|
|
@@ -118,20 +144,20 @@ struct topk | |||||||
| }); | ||||||||
| if(this->largest) | ||||||||
| std::partial_sort(data.begin(), | ||||||||
| data.begin() + k, | ||||||||
| data.begin() + actual_k, | ||||||||
| data.end(), | ||||||||
| compare_pair(std::greater<>{})); | ||||||||
| else | ||||||||
| std::partial_sort(data.begin(), | ||||||||
| data.begin() + k, | ||||||||
| data.begin() + actual_k, | ||||||||
| data.end(), | ||||||||
| compare_pair(std::less<>{})); | ||||||||
| std::transform(data.begin(), | ||||||||
| data.begin() + this->k, | ||||||||
| data.begin() + actual_k, | ||||||||
| y.begin(), | ||||||||
| [](const auto& p) { return p.first; }); | ||||||||
| std::transform(data.begin(), | ||||||||
| data.begin() + this->k, | ||||||||
| data.begin() + actual_k, | ||||||||
| y_ind.begin(), | ||||||||
| [](const auto& p) { return p.second; }); | ||||||||
| }); | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,7 +1,7 @@ | ||||||||||
| /* | ||||||||||
| * The MIT License (MIT) | ||||||||||
| * | ||||||||||
| * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. | ||||||||||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||||||||||
| * | ||||||||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||||||
| * of this software and associated documentation files (the "Software"), to deal | ||||||||||
|
|
@@ -26,6 +26,7 @@ | |||||||||
| #include <migraphx/ranges.hpp> | ||||||||||
| #include <migraphx/instruction.hpp> | ||||||||||
| #include <migraphx/make_op.hpp> | ||||||||||
| #include <optional> | ||||||||||
|
|
||||||||||
| namespace migraphx { | ||||||||||
| inline namespace MIGRAPHX_INLINE_NS { | ||||||||||
|
|
@@ -40,18 +41,6 @@ struct parse_topk : op_parser<parse_topk> | |||||||||
| onnx_parser::node_info info, | ||||||||||
| std::vector<instruction_ref> args) const | ||||||||||
| { | ||||||||||
| int64_t k = 0; | ||||||||||
| if(args.size() == 2) | ||||||||||
| { | ||||||||||
| auto arg_k = args.at(1)->eval(); | ||||||||||
| check_arg_empty(arg_k, "PARSE_TopK: k input must be constant"); | ||||||||||
| k = arg_k.at<int>(); | ||||||||||
| } | ||||||||||
| else if(contains(info.attributes, "k")) | ||||||||||
| { | ||||||||||
| k = info.attributes.at("k").i(); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| bool largest = true; | ||||||||||
| if(contains(info.attributes, "largest")) | ||||||||||
| { | ||||||||||
|
|
@@ -64,8 +53,27 @@ struct parse_topk : op_parser<parse_topk> | |||||||||
| axis = parser.parse_value(info.attributes.at("axis")).at<int>(); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| auto topk_ret = info.add_instruction( | ||||||||||
| make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); | ||||||||||
| std::optional<int64_t> k; | ||||||||||
| if(args.size() == 2) | ||||||||||
| { | ||||||||||
| auto arg_k = args.at(1)->eval(); | ||||||||||
| if(not arg_k.empty()) | ||||||||||
| { | ||||||||||
| k = arg_k.at<int>(); | ||||||||||
| } | ||||||||||
| } | ||||||||||
| else if(contains(info.attributes, "k")) | ||||||||||
| { | ||||||||||
| k = info.attributes.at("k").i(); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| auto topk_ret = | ||||||||||
| k.has_value() | ||||||||||
| ? info.add_instruction( | ||||||||||
| make_op("topk", {{"k", *k}, {"axis", axis}, {"largest", largest}}), | ||||||||||
| args.at(0)) | ||||||||||
| : info.add_instruction( | ||||||||||
| make_op("topk", {{"axis", axis}, {"largest", largest}}), args.at(0)); | ||||||||||
|
Comment on lines
+75
to
+76
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [format.py] reported by reviewdog 🐶
Suggested change
|
||||||||||
|
|
||||||||||
| auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret); | ||||||||||
| auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret); | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -37,17 +37,20 @@ namespace { | |||||||||||
| struct find_large_topk | ||||||||||||
| { | ||||||||||||
| std::size_t n_threshold = 0; | ||||||||||||
| auto matcher() const { return match::name("topk"); } | ||||||||||||
| auto matcher() const | ||||||||||||
| { | ||||||||||||
| return match::name("topk")(match::arg(0)(match::not_dynamic_shape())); | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+40
to
+43
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [format.py] reported by reviewdog 🐶
Suggested change
|
||||||||||||
|
|
||||||||||||
| void apply(module& m, const match::matcher_result& r) const | ||||||||||||
| { | ||||||||||||
| auto ins = r.result; | ||||||||||||
| auto input = ins->inputs().front(); | ||||||||||||
| auto op = ins->get_operator().to_value(); | ||||||||||||
| auto axis = op["axis"].to<std::int64_t>(); | ||||||||||||
| auto k = op["k"].to<std::int64_t>(); | ||||||||||||
| auto dims = input->get_shape().lens(); | ||||||||||||
| auto n = dims.at(axis); | ||||||||||||
| auto k = op["k"].is_null() ? static_cast<std::int64_t>(n) : op["k"].to<std::int64_t>(); | ||||||||||||
| if(n < n_threshold) | ||||||||||||
| return; | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -37,17 +37,18 @@ shape hip_topk::compute_shape(std::vector<shape> inputs) const | |||||||
| argument hip_topk::compute(context& ctx, const shape&, const std::vector<argument>& args) const | ||||||||
| { | ||||||||
| auto outputs = args.back().get_sub_objects(); | ||||||||
| auto actual_k = op.k.has_value() ? *op.k : static_cast<int64_t>(args[0].get_shape().lens()[op.axis]); | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [format.py] reported by reviewdog 🐶
Suggested change
|
||||||||
| return op.largest ? device::topk_largest(ctx.get_stream().get(), | ||||||||
| outputs.front(), | ||||||||
| outputs.back(), | ||||||||
| args[0], | ||||||||
| op.k, | ||||||||
| actual_k, | ||||||||
| op.axis) | ||||||||
| : device::topk_smallest(ctx.get_stream().get(), | ||||||||
| outputs.front(), | ||||||||
| outputs.back(), | ||||||||
| args[0], | ||||||||
| op.k, | ||||||||
| actual_k, | ||||||||
| op.axis); | ||||||||
| } | ||||||||
|
|
||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[format.py] reported by reviewdog 🐶