diff --git a/src/include/migraphx/dyn_output.hpp b/src/include/migraphx/dyn_output.hpp index ac3263cde3b..8e4b7f1d529 100644 --- a/src/include/migraphx/dyn_output.hpp +++ b/src/include/migraphx/dyn_output.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,7 +52,7 @@ struct compute_output_shape operator dyn_output() const { return ins_inputs([](const auto& x, shape ins_shape, const std::vector& inputs) { - if(ins_shape.dynamic()) + if(ins_shape.any_of_dynamic()) // some op returns a tuple shape e.g. TopK return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))}; return dyn_output{ins_shape, ins_shape}; }); diff --git a/src/include/migraphx/op/topk.hpp b/src/include/migraphx/op/topk.hpp index 5ff9393e24c..254df980c18 100644 --- a/src/include/migraphx/op/topk.hpp +++ b/src/include/migraphx/op/topk.hpp @@ -28,10 +28,12 @@ #include #include #include +#include #include #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -39,7 +41,7 @@ namespace op { struct topk { - int64_t k = 1; + std::optional k; int64_t axis = 0; bool largest = true; @@ -60,16 +62,38 @@ struct topk shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1, 2); - auto lens = inputs.at(0).lens(); + check_shapes{inputs, *this, true}.has(1, 2); auto type = inputs.at(0).type(); - lens[axis] = k; + if(inputs.at(0).dynamic()) + { + auto dyn_dims = inputs.at(0).dyn_dims(); + if(k.has_value()) + { + auto min_lens_vec = inputs.at(0).min_lens(); + auto max_lens_vec = inputs.at(0).max_lens(); + auto min_kk = std::min(static_cast(*k), min_lens_vec[axis]); + auto max_kk = std::min(static_cast(*k), max_lens_vec[axis]); + dyn_dims[axis] = {min_kk, max_kk}; + } - shape s_val{type, lens}; - shape s_ind{shape::int64_type, lens}; + shape s_val{type, dyn_dims}; + shape s_ind{shape::int64_type, dyn_dims}; + return shape({s_val, s_ind}); + } + else + { + auto lens = inputs.at(0).lens(); + if(k.has_value()) + { + auto kk = std::min(static_cast(*k), lens[axis]); + lens[axis] = kk; + } - return shape({s_val, s_ind}); + shape s_val{type, lens}; + shape s_ind{shape::int64_type, lens}; + return shape({s_val, s_ind}); + } } template @@ -84,13 +108,15 @@ struct topk }; } - argument compute(const shape& output_shape, std::vector args) const + argument compute(const dyn_output& dyn_out, std::vector args) const { + const auto& output_shape = dyn_out.computed_shape; const auto& vec_ss = output_shape.sub_shapes(); argument res_val{vec_ss.front()}; argument res_ind{vec_ss.back()}; auto in_val = args.front(); auto relements = in_val.get_shape().lens()[axis]; + auto actual_k = k.has_value() ? std::min(static_cast(*k), relements) : relements; auto make_indices = [&](const auto& m_idx) { return [&](int64_t i) { if(args.size() < 2) @@ -118,20 +144,20 @@ struct topk }); if(this->largest) std::partial_sort(data.begin(), - data.begin() + k, + data.begin() + actual_k, data.end(), compare_pair(std::greater<>{})); else std::partial_sort(data.begin(), - data.begin() + k, + data.begin() + actual_k, data.end(), compare_pair(std::less<>{})); std::transform(data.begin(), - data.begin() + this->k, + data.begin() + actual_k, y.begin(), [](const auto& p) { return p.first; }); std::transform(data.begin(), - data.begin() + this->k, + data.begin() + actual_k, y_ind.begin(), [](const auto& p) { return p.second; }); }); diff --git a/src/onnx/parse_topk.cpp b/src/onnx/parse_topk.cpp index 66ab9f7ad95..9d181928490 100644 --- a/src/onnx/parse_topk.cpp +++ b/src/onnx/parse_topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -40,18 +41,6 @@ struct parse_topk : op_parser onnx_parser::node_info info, std::vector args) const { - int64_t k = 0; - if(args.size() == 2) - { - auto arg_k = args.at(1)->eval(); - check_arg_empty(arg_k, "PARSE_TopK: k input must be constant"); - k = arg_k.at(); - } - else if(contains(info.attributes, "k")) - { - k = info.attributes.at("k").i(); - } - bool largest = true; if(contains(info.attributes, "largest")) { @@ -64,8 +53,27 @@ struct parse_topk : op_parser axis = parser.parse_value(info.attributes.at("axis")).at(); } - auto topk_ret = info.add_instruction( - make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); + std::optional k; + if(args.size() == 2) + { + auto arg_k = args.at(1)->eval(); + if(not arg_k.empty()) + { + k = arg_k.at(); + } + } + else if(contains(info.attributes, "k")) + { + k = info.attributes.at("k").i(); + } + + auto topk_ret = + k.has_value() + ? info.add_instruction( + make_op("topk", {{"k", *k}, {"axis", axis}, {"largest", largest}}), + args.at(0)) + : info.add_instruction( + make_op("topk", {{"axis", axis}, {"largest", largest}}), args.at(0)); auto ret_val = info.add_instruction(make_op("get_tuple_elem", {{"index", 0}}), topk_ret); auto ret_ind = info.add_instruction(make_op("get_tuple_elem", {{"index", 1}}), topk_ret); diff --git a/src/program.cpp b/src/program.cpp index 617a215f361..41889d7a1e0 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -671,7 +671,13 @@ std::vector program::eval(const parameter_map& params, if(trace_level > 0) { ctx.finish(); - std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; + // The ins_out map is populated from the main module's + // but when dynamic_code_object_op::compute recursively calls generic_eval + // on its runtime sub-module ins_out don't have it. + if(ins_out.find(ins) != ins_out.end()) + std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; + else + std::cout << "Run instruction: " << ins->name() << " (submodule)" << std::endl; } timer t{}; auto result = f(); diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index 19411680db8..83c91bc421e 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -37,7 +37,10 @@ namespace { struct find_large_topk { std::size_t n_threshold = 0; - auto matcher() const { return match::name("topk"); } + auto matcher() const + { + return match::name("topk")(match::arg(0)(match::not_dynamic_shape())); + } void apply(module& m, const match::matcher_result& r) const { @@ -45,9 +48,9 @@ struct find_large_topk auto input = ins->inputs().front(); auto op = ins->get_operator().to_value(); auto axis = op["axis"].to(); - auto k = op["k"].to(); auto dims = input->get_shape().lens(); auto n = dims.at(axis); + auto k = op["k"].is_null() ? static_cast(n) : op["k"].to(); if(n < n_threshold) return; diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 52272b1d7af..83f1f84acf1 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -159,11 +159,12 @@ struct dynamic_code_object_op return results.front(); } - if(output_arg.get_shape().dynamic()) - { - auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); - static_args[static_args.size() - 1] = output_arg.reshape(out_shape); - } + // static shape code can't be here, remove the check. + auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + // Skip JIT compilation when dynamic shape resolves to 0 elements at runtime + if(args.front().get_shape().elements() == 0) + return static_args.back(); // Rewrite submodule without dynamic shapes to be used as the IR for compilation module static_submod; diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index d04e81e218c..c1134e13b5c 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -252,7 +252,6 @@ struct hip_allocate_memory { return get_preallocation(ctx, id); } - void finalize(context& ctx, const shape&, const std::vector&) const { argument a = allocate_gpu(s); diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 5eca58aaa13..f19f2716226 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -239,8 +239,8 @@ struct miopen_apply instruction_ref insert_dynamic_code_object_op(instruction_ref ins) const { assert(ins->get_operator().name() == "gpu::precompile_op"); - - if(not ins->get_shape().dynamic()) + // some op returns a tuple shape e.g. TopK + if(not ins->get_shape().any_of_dynamic()) return ins; return mod->replace_instruction( diff --git a/src/targets/gpu/topk.cpp b/src/targets/gpu/topk.cpp index 2e799c650af..278d05d3a9d 100644 --- a/src/targets/gpu/topk.cpp +++ b/src/targets/gpu/topk.cpp @@ -37,17 +37,18 @@ shape hip_topk::compute_shape(std::vector inputs) const argument hip_topk::compute(context& ctx, const shape&, const std::vector& args) const { auto outputs = args.back().get_sub_objects(); + auto actual_k = op.k.has_value() ? *op.k : static_cast(args[0].get_shape().lens()[op.axis]); return op.largest ? device::topk_largest(ctx.get_stream().get(), outputs.front(), outputs.back(), args[0], - op.k, + actual_k, op.axis) : device::topk_smallest(ctx.get_stream().get(), outputs.front(), outputs.back(), args[0], - op.k, + actual_k, op.axis); } diff --git a/test/ref/topk.cpp b/test/ref/topk.cpp index 5e2ea0e0246..3f71f4f8879 100644 --- a/test/ref/topk.cpp +++ b/test/ref/topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -145,3 +145,92 @@ TEST_CASE(topk_smallest_custom_indices) std::vector gold_ind = {11, 13, 15, 14, 7, 9, 6, 10, 2, 5, 1, 3}; EXPECT(results.second == gold_ind); } + +// Test k > n with dynamic shapes: k=100 placeholder but runtime input has 5 elements +TEST_CASE(topk_k_greater_than_n_dynamic) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + // Dynamic shape: axis 0 ranges from 1 to 100 + std::vector dds = {{1, 100}}; + migraphx::shape s{migraphx::shape::float_type, dds}; + auto data = mm->add_parameter("data", s); + // k=100 is the max placeholder from parse time + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 0}, {"k", 100}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + p.compile(migraphx::make_target("ref")); + + // Runtime: only 5 elements + std::vector input_data = {3.0f, 1.0f, 4.0f, 1.5f, 2.0f}; + migraphx::shape input_fixed{migraphx::shape::float_type, {5}}; + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(input_fixed, input_data.data()); + auto rets = p.eval(pp); + + std::vector ret_val; + rets.front().visit([&](auto v) { ret_val.assign(v.begin(), v.end()); }); + std::vector ret_ind; + rets.back().visit([&](auto v) { ret_ind.assign(v.begin(), v.end()); }); + + // k=100 clamped to n=5, sorted descending + EXPECT(ret_val.size() == 5u); + std::vector gold_val = {4.0f, 3.0f, 2.0f, 1.5f, 1.0f}; + EXPECT(ret_val == gold_val); + std::vector gold_ind = {2, 0, 4, 3, 1}; + EXPECT(ret_ind == gold_ind); +} + +// Test k == n: k equals the axis dimension, should return all elements sorted +TEST_CASE(topk_k_equals_n) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + // k=5 equals axis=1 dimension of 5 + auto r = mm->add_instruction(migraphx::make_op("topk", {{"axis", 1}, {"k", 5}, {"largest", 0}}), + data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + p.compile(migraphx::make_target("ref")); + + std::vector input_data = { + 2.1, + 2.3, + 2.0, + 2.5, + 1.9, + 3.3, + 0.2, + 4.5, + 0.1, + 0.8, + 1.0, + 4.5, + 2.1, + 0.8, + 1.5, + }; + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(s, input_data.data()); + auto rets = p.eval(pp); + + std::vector ret_val; + rets.front().visit([&](auto v) { ret_val.assign(v.begin(), v.end()); }); + std::vector ret_ind; + rets.back().visit([&](auto v) { ret_ind.assign(v.begin(), v.end()); }); + + // All 5 elements returned per row, sorted ascending (smallest first) + EXPECT(ret_val.size() == 15u); + std::vector gold_val = { + 1.9, 2.0, 2.1, 2.3, 2.5, 0.1, 0.2, 0.8, 3.3, 4.5, 0.8, 1.0, 1.5, 2.1, 4.5}; + EXPECT(ret_val == gold_val); + std::vector gold_ind = {4, 2, 0, 1, 3, 3, 1, 4, 0, 2, 3, 0, 4, 2, 1}; + EXPECT(ret_ind == gold_ind); +} diff --git a/test/verify/test_topk_dynamic.cpp b/test/verify/test_topk_dynamic.cpp new file mode 100644 index 00000000000..89df61959fd --- /dev/null +++ b/test/verify/test_topk_dynamic.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +// Test k > n with dynamic shapes: k=100 placeholder but runtime input has fewer elements +template +struct test_topk_dynamic : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dds = {{1, 100}}; + migraphx::shape s{migraphx::shape::float_type, dds}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 0}, {"k", 100}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + return p; + } + + std::unordered_map get_test_dims() const + { + return {{"data", migraphx::shape{migraphx::shape::float_type, {N}}}}; + } +}; + +template struct test_topk_dynamic<10>;