Add dynamic shape support for TopK#4880
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #4880 +/- ##
===========================================
- Coverage 92.86% 92.66% -0.20%
===========================================
Files 585 588 +3
Lines 30152 30433 +281
===========================================
+ Hits 27998 28199 +201
- Misses 2154 2234 +80
🚀 New features to boost your workflow:
|
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This PR extends MIGraphX’s TopK support to handle dynamic input shapes and cases where k is not a compile-time constant (by using a placeholder k derived from the input shape), and adjusts dynamic-output handling in evaluation/GPU lowering to better support tuple-shaped outputs like TopK.
Changes:
- Update ONNX TopK parsing to allow non-constant
kby deriving a placeholderkfrom the input shape (static lens or dynamic max lens). - Update
op::topkto support dynamic shapes at shape-inference/eval time viadyn_output, including runtime clamping ofk. - Adjust GPU/lowering + dyn-output plumbing for tuple/dynamic shapes, and add tests for dynamic TopK scenarios.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| test/verify/test_topk_dynamic.cpp | Adds a verify test exercising TopK with a dynamic input shape and placeholder k. |
| test/ref/topk.cpp | Adds reference tests for k > n with dynamic input and for k == n behavior. |
| src/targets/gpu/lowering.cpp | Switches dynamic-shape detection to any_of_dynamic() to account for tuple outputs. |
| src/targets/gpu/include/migraphx/gpu/hip.hpp | Minor whitespace-only adjustment. |
| src/targets/gpu/compile_ops.cpp | Alters dynamic code-object execution to always compute/reshape runtime output shape and adds a skip-on-empty heuristic. |
| src/rewrite_topk.cpp | Disables large-TopK rewrite when input shape is dynamic. |
| src/program.cpp | Makes tracing more robust when evaluating submodules from dynamic code-object ops. |
| src/onnx/parse_topk.cpp | Removes the “k must be constant” restriction; derives placeholder k from input shape when needed. |
| src/include/migraphx/op/topk.hpp | Adds dynamic-shape support and switches compute to dyn_output, clamping k at runtime. |
| src/include/migraphx/dyn_output.hpp | Uses any_of_dynamic() so tuple outputs with dynamic subshapes get runtime-computed shapes. |
| auto input_shape = args.at(0)->get_shape(); | ||
| auto ndim = input_shape.ndim(); | ||
| auto norm_axis = axis < 0 ? axis + static_cast<int64_t>(ndim) : axis; | ||
| if(input_shape.dynamic()) | ||
| { |
| // k is not constant: use the input dimension along the topk axis | ||
| auto input_shape = args.at(0)->get_shape(); | ||
| auto ndim = input_shape.ndim(); | ||
| auto norm_axis = axis < 0 ? axis + static_cast<int64_t>(ndim) : axis; | ||
| if(input_shape.dynamic()) | ||
| { | ||
| k = input_shape.dyn_dims().at(norm_axis).get_interval().max; | ||
| } | ||
| else | ||
| { | ||
| k = input_shape.lens().at(norm_axis); |
| auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); | ||
| static_args[static_args.size() - 1] = output_arg.reshape(out_shape); | ||
| // Skip JIT compilation when dynamic shape resolves to 0 elements at runtime | ||
| if(args.front().get_shape().elements() == 0) | ||
| return static_args.back(); |
| auto topk_ret = info.add_instruction( | ||
| make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); | ||
|
|
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
|
Hi @pfultz2 |
| if(arg_k.empty()) | ||
| { | ||
| // k is not constant: use the input dimension along the topk axis |
There was a problem hiding this comment.
Rather than a placeholder value of max_len make the k attribute in topk a std::optional<int64_t>. This would be more obvious what is meant.
| auto ins = r.result; | ||
| auto input = ins->inputs().front(); | ||
| if(input->get_shape().dynamic()) | ||
| return; |
There was a problem hiding this comment.
Use not_dynamic_shape in the matcher rather than checking here.
… represent unknown k instead of using a placeholder value
| lens[axis] = k; | ||
| if(inputs.at(0).dynamic()) | ||
| { | ||
| auto dyn_dims = inputs.at(0).dyn_dims(); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto dyn_dims = inputs.at(0).dyn_dims(); | |
| auto dyn_dims = inputs.at(0).dyn_dims(); |
| argument res_ind{vec_ss.back()}; | ||
| auto in_val = args.front(); | ||
| auto relements = in_val.get_shape().lens()[axis]; | ||
| auto actual_k = k.has_value() ? std::min(static_cast<std::size_t>(*k), relements) : relements; |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto actual_k = k.has_value() ? std::min(static_cast<std::size_t>(*k), relements) : relements; | |
| auto actual_k = | |
| k.has_value() ? std::min(static_cast<std::size_t>(*k), relements) : relements; |
| : info.add_instruction( | ||
| make_op("topk", {{"axis", axis}, {"largest", largest}}), args.at(0)); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| : info.add_instruction( | |
| make_op("topk", {{"axis", axis}, {"largest", largest}}), args.at(0)); | |
| : info.add_instruction(make_op("topk", {{"axis", axis}, {"largest", largest}}), | |
| args.at(0)); |
| auto matcher() const | ||
| { | ||
| return match::name("topk")(match::arg(0)(match::not_dynamic_shape())); | ||
| } |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto matcher() const | |
| { | |
| return match::name("topk")(match::arg(0)(match::not_dynamic_shape())); | |
| } | |
| auto matcher() const { return match::name("topk")(match::arg(0)(match::not_dynamic_shape())); } |
| argument hip_topk::compute(context& ctx, const shape&, const std::vector<argument>& args) const | ||
| { | ||
| auto outputs = args.back().get_sub_objects(); | ||
| auto actual_k = op.k.has_value() ? *op.k : static_cast<int64_t>(args[0].get_shape().lens()[op.axis]); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto actual_k = op.k.has_value() ? *op.k : static_cast<int64_t>(args[0].get_shape().lens()[op.axis]); | |
| auto actual_k = | |
| op.k.has_value() ? *op.k : static_cast<int64_t>(args[0].get_shape().lens()[op.axis]); |
Motivation
The topk operator previously required a constant k and static input shapes, which will block the SSDMobileNetV2 model from running.
AMDMIGraphX/src/onnx/parse_topk.cpp
Line 47 in 93b8849
This PR adds dynamic input support for the topk op.
Affected model: SSDMobileNetV2
Technical Details
.\bin\migraphx-driver.exe perf .\bin\topk_shape_derived_k.onnx can run success.
onnx file
topk_shape_derived_k.zip
Add a
CHANGELOG.mdentry for any option other thanNot Applicable